diff --git a/tools/bazel.rc b/.bazelrc similarity index 69% rename from tools/bazel.rc rename to .bazelrc index 660e3d328038b618fefdf96d60863941d3a46edd..1945078789dcd48603ceb322c34ab2cd5af5eb59 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -24,19 +24,28 @@ build --define framework_shared_object=true # Please note that MKL on MacOS or windows is still not supported. # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. -build:mkl --define=using_mkl=true +build:mkl --define=build_with_mkl=true --define=enable_mkl=true build:mkl -c opt # This config option is used to enable MKL-DNN open source library only, # without depending on MKL binary version. -build:mkl_open_source_only --define=using_mkl_dnn_only=true +build:mkl_open_source_only --define=build_with_mkl_dnn_only=true +build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true +# Instruct clang to use LLD for linking. +# This only works with GPU builds currently, since Bazel sets -B/usr/bin in +# auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over +# the downloaded one. 
+build:download_clang_use_lld --linkopt='-fuse-ld=lld' build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true +build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true + build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true @@ -52,6 +61,18 @@ build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fn build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true +# Options extracted from configure script +build:gdr --define=with_gdr_support=true +build:ngraph --define=with_ngraph_support=true +build:verbs --define=with_verbs_support=true + +# Options to disable default on features +build:noaws --define=no_aws_support=true +build:nogcp --define=no_gcp_support=true +build:nohdfs --define=no_hdfs_support=true +build:nokafka --define=no_kafka_support=true +build:noignite --define=no_ignite_support=true + build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true build --define=grpc_no_ares=true @@ -60,5 +81,15 @@ build --spawn_strategy=standalone build --genrule_strategy=standalone build -c opt +# Other build flags. 
+build --define=grpc_no_ares=true + # Modular TF build options build:dynamic_kernels --define=dynamic_loaded_kernels=true + +# Default paths for TF_SYSTEM_LIBS +build --define=PREFIX=/usr +build --define=LIBDIR=$(PREFIX)/lib +build --define=INCLUDEDIR=$(PREFIX)/include + +# Do not commit the tf_configure.bazelrc line diff --git a/.gitignore b/.gitignore index 1709610fcd3b46910d703fe7244980e3dd2c2521..cb65f447d4a551266e237714a16d71b58bcfc51d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc /.tf_configure.bazelrc /bazel-* /bazel_pip @@ -14,6 +13,7 @@ __pycache__ *.swp .vscode/ cmake_build/ +tensorflow/contrib/cmake/_build/ .idea/** /build/ [Bb]uild/ diff --git a/CODEOWNERS b/CODEOWNERS index b9f0313cc6d59d3fbdcd014e1a528126d863075a..94cc865479cd6ab5cdb589490d3a2d650f06b160 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,67 @@ -# NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. 
-# /tensorflow/core/platform/windows/ @mrry -# /tensorflow/java/ @asimshankar -# /tensorflow/tensorboard/ @jart @dandelionmane -# /tensorflow/tools/docs/ @markdaoust +/tensorflow/core/debug @caisq +/tensorflow/core/platform/windows/ @mrry +/tensorflow/core/platform/s3 @yongtang +/tensorflow/go @asimshankar +/tensorflow/java/ @asimshankar +/tensorflow/python/debug @caisq +/tensorflow/python/tools/api/generator/ @annarev +/tensorflow/tensorboard/ @jart +/tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: /tensorflow/contrib/avro/ -# /tensorflow/contrib/batching/ @alextp @chrisolston -# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon -# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva -# /tensorflow/contrib/cmake/ @mrry @benoitsteiner -# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi -# /tensorflow/contrib/crf/ @kentonl -# /tensorflow/contrib/data/ @mrry -# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi -# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo -# /tensorflow/contrib/ffmpeg/ @fredbertsch -# NEED OWNER: /tensorflow/contrib/framework/ -# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/all_reduce +/tensorflow/contrib/batching/ @alextp @chrisolston +/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon +/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +/tensorflow/contrib/checkpoint/ @allenlavoie +/tensorflow/contrib/cluster_resolver/ @frankchn +/tensorflow/contrib/cmake/ @mrry +/tensorflow/contrib/copy_graph/ @tucker @poxvoculi +/tensorflow/contrib/crf/ @kentonl +/tensorflow/contrib/data/ @mrry +/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn +/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +/tensorflow/contrib/eager @alextp @asimshankar +/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +/tensorflow/contrib/ffmpeg/ @fredbertsch 
+/tensorflow/contrib/framework/ @ebrevdo +/tensorflow/contrib/gan/ @joel-shor +/tensorflow/contrib/graph_editor/ @purpledog # NEED OWNER: /tensorflow/contrib/grid_rnn/ -# /tensorflow/contrib/hvx/ @satok16 -# /tensorflow/contrib/integrate/ @shoyer -# /tensorflow/contrib/kernel_methods/ @petrosmol -# /tensorflow/contrib/ios_examples/ @petewarden -# /tensorflow/contrib/labeled_tensor/ @shoyer -# /tensorflow/contrib/layers/ @fchollet @martinwicke -# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -# /tensorflow/contrib/linalg/ @langmore -# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis -# /tensorflow/contrib/lookup/ @ysuematsu @andreasst -# /tensorflow/contrib/losses/ @alextp @ispirmustafa -# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg -# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa -# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq -# /tensorflow/contrib/opt/ @strategist333 -# /tensorflow/contrib/pi_examples/ @maciekcc -# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman -# /tensorflow/contrib/rnn/ @ebrevdo -# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh -# /tensorflow/contrib/seq2seq/ @lukaszkaiser -# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh -# /tensorflow/contrib/slim/ @sguada @thenbasilmanran -# /tensorflow/contrib/stateless/ @girving -# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -# /tensorflow/contrib/testing/ @dandelionmane -# /tensorflow/contrib/timeseries/ @allenlavoie -# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu -# /tensorflow/contrib/training/ @joel-shor @ebrevdo -# /tensorflow/contrib/util/ @sherrym +/tensorflow/contrib/hadoop @yongtang +/tensorflow/contrib/hvx/ @satok16 +/tensorflow/contrib/integrate/ @shoyer +/tensorflow/contrib/kafka @yongtang +/tensorflow/contrib/kernel_methods/ @petrosmol +/tensorflow/contrib/kinesis @yongtang +/tensorflow/contrib/ios_examples/ @petewarden 
+/tensorflow/contrib/labeled_tensor/ @shoyer +/tensorflow/contrib/layers/ @fchollet @martinwicke +/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +/tensorflow/contrib/lookup/ @ysuematsu @andreasst +/tensorflow/contrib/losses/ @alextp @ispirmustafa +/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +/tensorflow/contrib/opt/ @strategist333 @alextp +/tensorflow/contrib/pi_examples/ @maciekcc +/tensorflow/contrib/quantization/ @petewarden +/tensorflow/contrib/rnn/ @ebrevdo @scottzhu +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang +/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +/tensorflow/contrib/slim/ @sguada @thenbasilmanran +/tensorflow/contrib/stateless/ @girving @alextp +/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank +/tensorflow/contrib/tensorrt/ @aaroey +# NEED OWNER: /tensorflow/contrib/testing/ +/tensorflow/contrib/timeseries/ @allenlavoie +/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj +/tensorflow/contrib/training/ @joel-shor @ebrevdo +/tensorflow/contrib/util/ @sherrym + +/third_party/systemlibs/ @perfinion diff --git a/README.md b/README.md index 823c6880967a29f3e4838f7c120961c1b16e2b5f..c3455474260b2db56f1f585b70af9c259704d01a 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,21 @@ subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). 
## Installation -*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* + +To install the current release for CPU-only: + +``` +pip install tensorflow +``` + +Use the GPU package for CUDA-enabled GPU cards: + +``` +pip install tensorflow-gpu +``` + +*See [Installing TensorFlow](https://www.tensorflow.org/install) for detailed +instructions, and how to build from source.* People who are a little more adventurous can also try our nightly binaries: @@ -48,15 +62,12 @@ $ python ``` ```python >>> import tensorflow as tf +>>> tf.enable_eager_execution() +>>> tf.add(1, 2) +3 >>> hello = tf.constant('Hello, TensorFlow!') ->>> sess = tf.Session() ->>> sess.run(hello) +>>> hello.numpy() 'Hello, TensorFlow!' ->>> a = tf.constant(10) ->>> b = tf.constant(32) ->>> sess.run(a + b) -42 ->>> sess.close() ``` Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). 
@@ -90,27 +101,29 @@ The TensorFlow project strives to abide by generally accepted best practices in | **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) | | **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | | **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | +| **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) | +| **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) 
[Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) | ### Community Supported Builds -| Build Type | Status | Artifacts | -| --- | --- | --- | -| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | -| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | -| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | -| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | - +Build Type | Status | Artifacts +---------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA +**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA +**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) ## For more information -* [Tensorflow Blog](https://medium.com/tensorflow) -* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Website](https://www.tensorflow.org) +* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) +* [TensorFlow Twitter](https://twitter.com/tensorflow) +* [TensorFlow Blog](https://medium.com/tensorflow) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) * [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [Tensorflow Twitter](https://twitter.com/tensorflow) -* [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) diff --git a/RELEASE.md b/RELEASE.md index 763ef3b279dde209ed387534032deae40a33a9e4..20e1d9217b7684e696d0abf427eef9ab9548d1b7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,86 @@ +# Release 1.11.0 + +## Major Features and Improvements + +* Nvidia GPU: + * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) +* Google Cloud TPU: + * Experimental tf.data integration for Keras on Google Cloud TPUs. + * Experimental / preview support for eager execution on Google Cloud TPUs. +* DistributionStrategy: + * Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs. + * Add multi-worker DistributionStrategy and standalone client support in Estimator. 
See [README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details. +* Add C, C++, and Python functions for querying kernels + +## Breaking Changes + +* Keras: + * The default values for tf.keras `RandomUniform`, `RandomNormal`, and `TruncatedNormal` initializers have been changed to match those in external Keras. + * Breaking change: `model.get_config()` on a Sequential model now returns a config dictionary (consistent with other Model instances) instead of a list of configs for the underlying layers. + +## Bug Fixes and Other Changes + +* C++: + * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure. +* tf.data: + * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. + * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files. + * Renamed BigTable class to BigtableTable for clarity + * Document use of the Cloud Bigtable API + * Adding `tf.contrib.data.reduce_dataset` which can be used to reduce a dataset to a single element. + * Generalization of `tf.contrib.data.sliding_window_batch`. +* INC: + * Runtime improvements to triangular solve. +* `tf.contrib`: + * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows the use of `padding=same`. + * Add documentation clarifying the differences between tf.fill and tf.constant. + * Add experimental IndexedDatasets. + * Add selective registration target using the lite proto runtime. 
+ * Add simple Tensor and DataType classes to TensorFlow Lite Java + * Add support for bitcasting to/from uint32 and uint64. + * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator). + * Adds leaf index modes as an argument. + * Allow a different output shape from the input in tf.contrib.image.transform. + * Change the state_size order of the StackedRNNCell to be natural order. To keep the existing behavior, user can add reverse_state_order=True when constructing the StackedRNNCells. + * Deprecate self.test_session() in favor of self.session() or self.cached_session(). + * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon) + * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one. + * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator. + * Fix toco compilation/execution on Windows + * GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in. + * It is now safe to call any of the C API's TF_Delete\* functions on nullptr + * Log some errors on Android to logcat + * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models. + * Optional bucket location check for the GCS Filesystem. + * Performance enhancements for StringSplitOp & StringSplitV2Op. + * Performance improvements for regex replace operations. + * TFRecordWriter now raises an error if .write() fails. + * TPU: More helpful error messages in TPUClusterResolvers. + * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. 
As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time. + * The protocol used for Estimator training is now configurable in RunConfig. + * Triangular solve performance improvements. + * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will be used to replace the existing zero_state() method. + * Update initialization of variables in Keras. + * Updates to "constrained_optimization" in tensorflow/contrib. + * boosted trees: adding pruning mode + * tf.train.Checkpoint does not delete old checkpoints by default. + * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Aapeli, adoda, Ag Ramesh, Amogh Mannekote, Andrew Gibiansky, Andy Craze, Anirudh Koul, Aurelien Geron, Avijit, Avijit-Nervana, Ben, Benjamin H. 
Myara, bhack, Brett Koonce, Cao Zongyan, cbockman, cheerss, Chikanaga Tomoyuki, Clayne Robison, cosine0, Cui Wei, Dan J, David, David Norman, Dmitry Klimenkov, Eliel Hojman, Florian Courtial, fo40225, formath, Geoffrey Irving, gracehoney, Grzegorz Pawelczak, Guoliang Hua, Guozhong Zhuang, Herman Zvonimir Došilović, HuiyangFei, Jacker, Jan Hünnemeyer, Jason Taylor, Jason Zaman, Jesse, Jiang,Zhoulong, Jiawei Zhang, Jie, Joe Yearsley, Johannes Schmitz, Jon Perl, Jon Triebenbach, Jonathan, Jonathan Hseu, Jongmin Park, Justin Shenk, karl@kubx.ca, Kate Hodesdon, Kb Sriram, Keishi Hattori, Kenneth Blomqvist, Koan-Sin Tan, Li Liangbin, Li, Yiqiang, Loo Rong Jie, Madiyar, Mahmoud Abuzaina, Mark Ryan, Matt Dodge, mbhuiyan, melvinljy96, Miguel Mota, Nafis Sadat, Nathan Luehr, naurril, Nehal J Wani, Niall Moran, Niranjan Hasabnis, Nishidha Panpaliya, npow, olicht, Pei Zhang, Peng Wang (Simpeng), Peng Yu, Philipp Jund, Pradeep Banavara, Pratik Kalshetti, qwertWZ, Rakesh Chada, Randy West, Ray Kim, Rholais Lii, Robin Richtsfeld, Rodrigo Silveira, Ruizhi, Santosh Kumar, Seb Bro, Sergei Lebedev, sfujiwara, Shaba Abhiram, Shashi, SneakyFish5, Soila Kavulya, Stefan Dyulgerov, Steven Winston, Sunitha Kambhampati, Surry Shome, Taehoon Lee, Thor Johnsen, Tristan Rice, TShapinsky, tucan, tucan9389, Vicente Reyes, Vilmar-Hillow, Vitaly Lavrukhin, wangershi, weidan.kong, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Wim Glenn, XFeiF, Yan Facai (颜发才), Yanbo Liang, Yong Tang, Yoshihiro Yamazaki, Yuan (Terry) Tang, Yuan, Man, zhaoyongke, Áron +Ricardo Perez-Lopez, 张天启, 张晓飞 + + +# Release 1.10.1 +## Bug Fixes and Other Changes + +* `tf.keras`: + * Fixing keras on Cloud TPUs. No new binaries will be built for Windows. + + # Release 1.10.0 ## Major Features And Improvements @@ -11,7 +94,7 @@ ## Breaking Changes -* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. 
TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [TensorFlow GPU support](https://www.tensorflow.org/install/gpu) and [Build TensorFlow from source](https://www.tensorflow.org/install/source). * Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. ## Bug Fixes and Other Changes diff --git a/configure.py b/configure.py index 10fee6993eb52f71e2d0ad4d4c23eb3b53adc537..b564da27227ec07713f91e925ea292b35f0f02df 100644 --- a/configure.py +++ b/configure.py @@ -35,24 +35,30 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '2.2' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 -_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' -_TF_BAZELRC = 
os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) -_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') +_TF_WORKSPACE_ROOT = '' +_TF_BAZELRC = '' + +NCCL_LIB_PATHS = [ + 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' +] + +if platform.machine() == 'ppc64le': + _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' +else: + _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() class UserInputError(Exception): @@ -153,14 +159,18 @@ def get_python_path(environ_cp, python_bin_path): if environ_cp.get('PYTHONPATH'): python_paths = environ_cp.get('PYTHONPATH').split(':') try: - library_paths = run_shell( - [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') + library_paths = run_shell([ + python_bin_path, '-c', + 'import site; print("\\n".join(site.getsitepackages()))' + ]).split('\n') except subprocess.CalledProcessError: - library_paths = [run_shell( - [python_bin_path, '-c', - 'from distutils.sysconfig import get_python_lib;' - 'print(get_python_lib())'])] + library_paths = [ + run_shell([ + python_bin_path, '-c', + 'from distutils.sysconfig import get_python_lib;' + 'print(get_python_lib())' + ]) + ] all_paths = set(python_paths + library_paths) @@ -187,8 +197,7 @@ def setup_python(environ_cp): environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path, default_python_bin_path) # Check if the path is valid - if os.path.isfile(python_bin_path) and os.access( - python_bin_path, os.X_OK): + if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break elif not os.path.exists(python_bin_path): print('Invalid python path: %s cannot be found.' 
% python_bin_path) @@ -217,7 +226,7 @@ def setup_python(environ_cp): python_lib_path = default_python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path - python_major_version = get_python_major_version(python_bin_path) + _ = get_python_major_version(python_bin_path) # Convert python path to Windows style before writing into bazel.rc if is_windows() or is_cygwin(): @@ -230,15 +239,16 @@ def setup_python(environ_cp): environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open(os.path.join( - _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: + with open( + os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), + 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(workspace_path): +def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(workspace_path, '.bazelrc') + bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') data = [] if os.path.exists(bazelrc_path): @@ -249,20 +259,15 @@ def reset_tf_configure_bazelrc(workspace_path): if _TF_BAZELRC_FILENAME in l: continue f.write('%s\n' % l) - if is_windows(): - tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") - else: - tf_bazelrc_path = _TF_BAZELRC - f.write('import %s\n' % tf_bazelrc_path) - + f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. These files could interfere with Bazel parsing. 
""" - makefile_download_dir = os.path.join( - _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') + makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow', + 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -330,9 +335,8 @@ def get_var(environ_cp, 'Environment variable %s must be set as a boolean indicator.\n' 'The following are accepted as TRUE : %s.\n' 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % ( - var_name, ', '.join(true_strings), ', '.join(false_strings), - var)) + 'Current value is %s.' % (var_name, ', '.join(true_strings), + ', '.join(false_strings), var)) while var is None: user_input_origin = get_input(question) @@ -355,8 +359,12 @@ def get_var(environ_cp, return var -def set_build_var(environ_cp, var_name, query_item, option_name, - enabled_by_default, bazel_config_name=None): +def set_build_var(environ_cp, + var_name, + query_item, + option_name, + enabled_by_default, + bazel_config_name=None): """Set if query_item will be enabled for the build. Ask user if query_item will be enabled. Default is used if no input is given. @@ -375,12 +383,14 @@ def set_build_var(environ_cp, var_name, query_item, option_name, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var if var == '1': - write_to_bazelrc('build --define %s=true' % option_name) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) + write_to_bazelrc('build --config=%s' % bazel_config_name) elif bazel_config_name is not None: # TODO(mikecase): Migrate all users of configure.py to use --config Bazel # options and not to set build configs through environment variables. 
- write_to_bazelrc('build:%s --define %s=true' - % (bazel_config_name, option_name)) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) def set_action_env_var(environ_cp, @@ -447,7 +457,8 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) + curr_version = run_shell( + ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -486,7 +497,7 @@ def set_cc_opt_flags(environ_cp): elif is_windows(): default_cc_opt_flags = '/arch:AVX' else: - default_cc_opt_flags = '-march=native' + default_cc_opt_flags = '-march=native -Wno-sign-compare' question = ('Please specify optimization flags to use during compilation when' ' bazel option "--config=opt" is specified [Default is %s]: ' ) % default_cc_opt_flags @@ -499,6 +510,7 @@ def set_cc_opt_flags(environ_cp): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') + def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -581,16 +593,14 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) -def prompt_loop_or_load_from_env( - environ_cp, - var_name, - var_default, - ask_for_var, - check_success, - error_msg, - suppress_default_error=False, - n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS -): +def prompt_loop_or_load_from_env(environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS): """Loop over user prompts for an ENV param until receiving a valid response. 
For the env param var_name, read from the environment or verify user input @@ -629,9 +639,7 @@ def prompt_loop_or_load_from_env( ) for _ in range(n_ask_attempts): - val = get_from_env_or_user_or_default(environ_cp, - var_name, - full_query, + val = get_from_env_or_user_or_default(environ_cp, var_name, full_query, default) if check_success(val): break @@ -639,9 +647,9 @@ def prompt_loop_or_load_from_env( print(error_msg % val) environ_cp[var_name] = '' else: - raise UserInputError('Invalid %s setting was provided %d times in a row. ' - 'Assuming to be a scripting mistake.' % - (var_name, n_ask_attempts)) + raise UserInputError( + 'Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts)) environ_cp[var_name] = val return val @@ -650,8 +658,8 @@ def prompt_loop_or_load_from_env( def create_android_ndk_rule(environ_cp): """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" if is_windows() or is_cygwin(): - default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % - environ_cp['APPDATA']) + default_ndk_path = cygpath( + '%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA']) elif is_macos(): default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] else: @@ -668,8 +676,7 @@ def create_android_ndk_rule(environ_cp): ask_for_var='Please specify the home path of the Android NDK to use.', check_success=valid_ndk_path, error_msg=('The path %s or its child file "source.properties" ' - 'does not exist.') - ) + 'does not exist.')) write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', check_ndk_level(android_ndk_home_path)) @@ -703,9 +710,9 @@ def create_android_sdk_rule(environ_cp): api_levels = [x.replace('android-', '') for x in api_levels] def valid_api_level(api_level): - return os.path.exists(os.path.join(android_sdk_home_path, - 'platforms', - 'android-' + api_level)) + return os.path.exists( + 
os.path.join(android_sdk_home_path, 'platforms', + 'android-' + api_level)) android_api_level = prompt_loop_or_load_from_env( environ_cp, @@ -720,9 +727,8 @@ def create_android_sdk_rule(environ_cp): versions = sorted(os.listdir(build_tools)) def valid_build_tools(version): - return os.path.exists(os.path.join(android_sdk_home_path, - 'build-tools', - version)) + return os.path.exists( + os.path.join(android_sdk_home_path, 'build-tools', version)) android_build_tools_version = prompt_loop_or_load_from_env( environ_cp, @@ -736,10 +742,8 @@ def create_android_sdk_rule(environ_cp): write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', android_build_tools_version) - write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', - android_api_level) - write_action_env_to_bazelrc('ANDROID_SDK_HOME', - android_sdk_home_path) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -798,6 +802,7 @@ def reformat_version_sequence(version_str, sequence_count): Args: version_str: String, the version string. sequence_count: int, an integer. + Returns: string, reformatted version string. 
""" @@ -841,18 +846,25 @@ def set_tf_cuda_version(environ_cp): if is_windows(): cuda_rt_lib_paths = ['lib/x64/cudart.lib'] elif is_linux(): - cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version) - for x in ['lib64', 'lib/x86_64-linux-gnu']] + cuda_rt_lib_paths = [ + '%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [ + 'lib64', + 'lib/powerpc64le-linux-gnu', + 'lib/x86_64-linux-gnu', + ] + ] elif is_macos(): cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths] + cuda_toolkit_paths_full = [ + os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths + ] if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): break # Reset and retry print('Invalid path to CUDA %s toolkit. %s cannot be found' % - (tf_cuda_version, cuda_toolkit_path_full)) + (tf_cuda_version, cuda_toolkit_paths_full)) environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' @@ -872,7 +884,7 @@ def set_tf_cudnn_version(environ_cp): """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" ask_cudnn_version = ( 'Please specify the cuDNN version you want to use. 
' - '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION + '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_cudnn_version = get_from_env_or_user_or_default( @@ -919,8 +931,8 @@ def set_tf_cudnn_version(environ_cp): cudnn_path_from_ldconfig) if cudnn_path_from_ldconfig: cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1) - if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, - tf_cudnn_version)): + if os.path.exists( + '%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)): cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) break @@ -1029,7 +1041,7 @@ def set_tf_tensorrt_install_path(environ_cp): for lib_file in possible_files: if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) - if len(matches.groups()) == 0: + if not matches.groups(): continue ver_str = matches.group(1) ver = convert_version_to_int(ver_str) if len(ver_str) else 0 @@ -1085,7 +1097,7 @@ def set_tf_tensorrt_install_path(environ_cp): def set_tf_nccl_install_path(environ_cp): - """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + """Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION. Args: environ_cp: copy of the os.environ. @@ -1098,59 +1110,119 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. If NCCL %s is not ' - 'installed, then you can use version 1.3 that can be fetched ' - 'automatically but it may have worse performance with multiple GPUs. ' - '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) + 'Please specify the locally installed NCCL version you want to use. 
' + '[Default is to use https://github.com/nvidia/nccl]: ') for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( - environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) - tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '') + + if not tf_nccl_version: + break # No need to get install path, building the open source code. - if tf_nccl_version == '1': - break # No need to get install path, NCCL 1 is a GitHub repo. + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) - # TODO(csigg): Look with ldconfig first if we can find the library in paths + # Look with ldconfig first if we can find the library in paths # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding # include directory. This is where the NCCL .deb packages install them. - # Then ask the user if we should use that. Instead of a single - # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to - # nccl_configure.bzl - default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') - ask_nccl_path = (r'Please specify the location where NCCL %s library is ' - 'installed. Refer to README.md for more details. [Default ' - 'is %s]:') % (tf_nccl_version, default_nccl_path) - nccl_install_path = get_from_env_or_user_or_default( - environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. 
- nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) - if is_windows() or is_cygwin(): - nccl_install_path = cygpath(nccl_install_path) - - if is_windows(): - nccl_lib_path = 'lib/x64/nccl.lib' - elif is_linux(): - nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version - elif is_macos(): - nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version - - nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) - nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): - # Set NCCL_INSTALL_PATH - environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path - write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) - break - - # Reset and Retry - print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + # First check to see if NCCL is in the ldconfig. + # If its found, use that location. + if is_linux(): + ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' + nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) + nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)', + nccl2_path_from_ldconfig) + if nccl2_path_from_ldconfig: + nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1) + if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)): + nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig) + print('NCCL libraries found in ' + nccl2_path_from_ldconfig) + + # Check if this is the main system lib location + if re.search('.*linux-gnu', nccl_install_path): + trunc_nccl_install_path = '/usr' + print('This looks like a system path.') + else: + trunc_nccl_install_path = nccl_install_path + '/..' 
+ + # Look for header + nccl_hdr_path = trunc_nccl_install_path + '/include' + print('Assuming NCCL header path is ' + nccl_hdr_path) + if os.path.exists(nccl_hdr_path + '/nccl.h'): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + + # Set NCCL_HDR_PATH + environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path + write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path) + break + else: + print( + 'The header for NCCL2 cannot be found. Please install the libnccl-dev package.' + ) + else: + print('NCCL2 is listed by ldconfig but the library is not found. ' + 'Your ldconfig is out of date. Please run sudo ldconfig.') + else: + # NCCL is not found in ldconfig. Ask the user for the location. + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = ( + r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. 
+ nccl_install_path = os.path.realpath( + os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version + nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename) + if not os.path.exists(nccl_lpath): + for relative_path in NCCL_LIB_PATHS: + path = '%s/%s%s' % (nccl_install_path, relative_path, + nccl_lib_filename) + if os.path.exists(path): + print('NCCL found at ' + path) + nccl_lib_path = path + break + else: + nccl_lib_path = nccl_lpath + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join( + os.path.dirname(nccl_lib_path), '../include/nccl.h') + print('Assuming NCCL header path is ' + nccl_hdr_path) + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path) + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', + os.path.dirname(nccl_lib_path)) + + # Set NCCL_HDR_PATH + environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path) + write_action_env_to_bazelrc('NCCL_HDR_PATH', + os.path.dirname(nccl_hdr_path)) + break + + # Reset and Retry + print( + 'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, nccl_hdr_path)) - environ_cp['TF_NCCL_VERSION'] = '' + environ_cp['TF_NCCL_VERSION'] = '' else: raise UserInputError('Invalid TF_NCCL setting was provided %d ' 'times in a row. Assuming to be a scripting mistake.' 
% @@ -1160,12 +1232,12 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) - def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. Args: environ_cp: copy of the os.environ. + Returns: string of native cuda compute capabilities, separated by comma. """ @@ -1290,8 +1362,7 @@ def set_computecpp_toolkit_path(environ_cp): else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(toolkit_path, - sycl_rt_lib_path) + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) exists = os.path.exists(sycl_rt_lib_path_full) if not exists: print('Invalid SYCL %s library path. %s cannot be found' % @@ -1319,8 +1390,8 @@ def set_trisycl_include_dir(environ_cp): ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' 'include directory. (Use --config=sycl_trisycl ' 'when building with Bazel) ' - '[Default is %s]: ' - ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + '[Default is %s]: ') % ( + _DEFAULT_TRISYCL_INCLUDE_DIR) while True: trisycl_include_dir = get_from_env_or_user_or_default( @@ -1329,13 +1400,12 @@ def set_trisycl_include_dir(environ_cp): if os.path.exists(trisycl_include_dir): break - print('Invalid triSYCL include directory, %s cannot be found' - % (trisycl_include_dir)) + print('Invalid triSYCL include directory, %s cannot be found' % + (trisycl_include_dir)) # Set TRISYCL_INCLUDE_DIR environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', - trisycl_include_dir) + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) def set_mpi_home(environ_cp): @@ -1345,8 +1415,9 @@ def set_mpi_home(environ_cp): default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) def valid_mpi_path(mpi_home): - exists = (os.path.exists(os.path.join(mpi_home, 'include')) and - os.path.exists(os.path.join(mpi_home, 'lib'))) + exists = ( + 
os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) if not exists: print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % (os.path.join(mpi_home, 'include'), @@ -1395,16 +1466,22 @@ def set_other_mpi_vars(environ_cp): raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) -def set_grpc_build_flags(): - write_to_bazelrc('build --define grpc_no_ares=true') - - def set_system_libs_flag(environ_cp): syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') - syslibs = ','.join(sorted(syslibs.split(','))) - if syslibs and syslibs != '': + if syslibs: + if ',' in syslibs: + syslibs = ','.join(sorted(syslibs.split(','))) + else: + syslibs = ','.join(sorted(syslibs.split())) write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) + if 'PREFIX' in environ_cp: + write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) + if 'LIBDIR' in environ_cp: + write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) + if 'INCLUDEDIR' in environ_cp: + write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) + def set_windows_build_flags(environ_cp): """Set Windows specific build options.""" @@ -1424,11 +1501,9 @@ def set_windows_build_flags(environ_cp): if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', - True, - ('Would you like to override eigen strong inline for some C++ ' - 'compilation to reduce the compilation time?'), - 'Eigen strong inline overridden.', - 'Not overriding eigen strong inline, ' + True, ('Would you like to override eigen strong inline for some C++ ' + 'compilation to reduce the compilation time?'), + 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, ' 'some compilations could take more than 20 mins.'): # Due to a known MSVC compiler issue # https://github.com/tensorflow/tensorflow/issues/10521 @@ -1444,29 +1519,31 @@ def config_info_line(name, help_text): def main(): + global _TF_WORKSPACE_ROOT + global 
_TF_BAZELRC + parser = argparse.ArgumentParser() - parser.add_argument("--workspace", - type=str, - default=_TF_WORKSPACE_ROOT, - help="The absolute path to your active Bazel workspace.") + parser.add_argument( + '--workspace', + type=str, + default=os.path.abspath(os.path.dirname(__file__)), + help='The absolute path to your active Bazel workspace.') args = parser.parse_args() + _TF_WORKSPACE_ROOT = args.workspace + _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. environ_cp = dict(os.environ) check_bazel_version('0.15.0') - reset_tf_configure_bazelrc(args.workspace) + reset_tf_configure_bazelrc() cleanup_makefile() setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_AWS'] = '0' - environ_cp['TF_NEED_GCP'] = '0' - environ_cp['TF_NEED_HDFS'] = '0' - environ_cp['TF_NEED_JEMALLOC'] = '0' - environ_cp['TF_NEED_KAFKA'] = '0' environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' @@ -1475,14 +1552,10 @@ def main(): # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' - environ_cp['TF_ENABLE_XLA'] = '0' - environ_cp['TF_NEED_GDR'] = '0' - environ_cp['TF_NEED_VERBS'] = '0' environ_cp['TF_NEED_MPI'] = '0' environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): - environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading @@ -1490,26 +1563,11 @@ def main(): # runtime to allow the Tensorflow testcases which compare numpy # results to Tensorflow results to succeed. 
if is_ppc64le(): - write_action_env_to_bazelrc("OMP_NUM_THREADS", 1) - - set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', - 'with_jemalloc', True) - set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', - 'with_gcp_support', True, 'gcp') - set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', - 'with_hdfs_support', True, 'hdfs') - set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform', - 'with_aws_support', True, 'aws') - set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', - 'with_kafka_support', True, 'kafka') + write_action_env_to_bazelrc('OMP_NUM_THREADS', 1) + + xla_enabled_by_default = is_linux() set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', - False, 'xla') - set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', - False, 'gdr') - set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', - False, 'verbs') - set_build_var(environ_cp, 'TF_NEED_NGRAPH', 'nGraph', - 'with_ngraph_support', False, 'ngraph') + xla_enabled_by_default, 'xla') set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': @@ -1521,6 +1579,13 @@ def main(): else: set_trisycl_include_dir(environ_cp) + set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False) + if (environ_cp.get('TF_NEED_ROCM') == '1' and + 'LD_LIBRARY_PATH' in environ_cp and + environ_cp.get('LD_LIBRARY_PATH') != '1'): + write_action_env_to_bazelrc('LD_LIBRARY_PATH', + environ_cp.get('LD_LIBRARY_PATH')) + set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): @@ -1543,6 +1608,10 @@ def main(): if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) + else: + # Use downloaded LLD for linking. 
+ write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld') + write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld') else: # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -1557,36 +1626,56 @@ def main(): write_to_bazelrc('build --config=download_clang') write_to_bazelrc('test --config=download_clang') + # SYCL / ROCm / CUDA are mutually exclusive. + # At most 1 GPU platform can be configured. + gpu_platform_count = 0 + if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': + gpu_platform_count += 1 + if environ_cp.get('TF_NEED_ROCM') == '1': + gpu_platform_count += 1 + if environ_cp.get('TF_NEED_CUDA') == '1': + gpu_platform_count += 1 + if gpu_platform_count >= 2: + raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. ' + 'At most 1 GPU platform can be configured.') + set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': set_mpi_home(environ_cp) set_other_mpi_vars(environ_cp) - set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_system_libs_flag(environ_cp) if is_windows(): set_windows_build_flags(environ_cp) - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): + # Add a config option to build TensorFlow 2.0 API. + write_to_bazelrc('build:v2 --define=tf_api_version=2') + + if get_var(environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) - # On Windows, we don't have MKL support and the build is always monolithic. 
- # So no need to print the following message. - # TODO(pcloudy): remove the following if check when they make sense on Windows - if not is_windows(): - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. See tools/bazel.rc for ' - 'more details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line('monolithic', 'Config for mostly static monolithic build.') + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. See .bazelrc for more ' + 'details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') + config_info_line('gdr', 'Build with GDR support.') + config_info_line('verbs', 'Build with libverbs support.') + config_info_line('ngraph', 'Build with Intel nGraph support.') + + print('Preconfigured Bazel build configs to DISABLE default on features:') + config_info_line('noaws', 'Disable AWS S3 filesystem support.') + config_info_line('nogcp', 'Disable GCP support.') + config_info_line('nohdfs', 'Disable HDFS support.') + config_info_line('noignite', 'Disable Apacha Ignite support.') + config_info_line('nokafka', 'Disable Apache Kafka support.') + if __name__ == '__main__': main() + diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9cc4c4567b4b2ea6bc29919bfa03c190c9005fbc..77e3baaff198b402dc04daa1b11e4007b9906b23 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -12,6 +12,7 @@ exports_files([ # The leakr files are used by //third_party/cloud_tpu. 
"leakr_badwords.dic", "leakr_badfiles.dic", + "leakr_file_type_recipe.ftrcp", ]) load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") @@ -23,11 +24,25 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files") +load( + "//tensorflow/python/tools/api/generator:api_init_files.bzl", + "TENSORFLOW_API_INIT_FILES", # @unused +) +load( + "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", + "TENSORFLOW_API_INIT_FILES_V1", # @unused +) load( "//third_party/ngraph:build_defs.bzl", "if_ngraph", ) +# @unused +TENSORFLOW_API_INIT_FILES_V2 = ( + TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. config_setting( @@ -188,81 +203,46 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(jhseu): Enable on other platforms other than Linux. -config_setting( - name = "with_jemalloc_linux_x86_64", - define_values = {"with_jemalloc": "true"}, - values = {"cpu": "k8"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_jemalloc_linux_ppc64le", - define_values = {"with_jemalloc": "true"}, - values = {"cpu": "ppc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "with_default_optimizations", define_values = {"with_default_optimizations": "true"}, visibility = ["//visibility:public"], ) +# Features that are default ON are handled differently below. 
+# config_setting( - name = "with_gcp_support", - define_values = {"with_gcp_support": "true"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support", - define_values = {"with_hdfs_support": "true"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support", - define_values = {"with_aws_support": "true"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_kafka_support", - define_values = {"with_kafka_support": "true"}, + name = "no_aws_support", + define_values = {"no_aws_support": "false"}, visibility = ["//visibility:public"], ) -# Crosses between platforms and file system libraries not supported on those -# platforms due to limitations in nested select() statements. config_setting( - name = "with_gcp_support_windows_override", - define_values = {"with_gcp_support": "true"}, - values = {"cpu": "x64_windows"}, + name = "no_gcp_support", + define_values = {"no_gcp_support": "false"}, visibility = ["//visibility:public"], ) config_setting( - name = "with_hdfs_support_windows_override", - define_values = {"with_hdfs_support": "true"}, - values = {"cpu": "x64_windows"}, + name = "no_hdfs_support", + define_values = {"no_hdfs_support": "false"}, visibility = ["//visibility:public"], ) config_setting( - name = "with_aws_support_windows_override", - define_values = {"with_aws_support": "true"}, - values = {"cpu": "x64_windows"}, + name = "no_ignite_support", + define_values = {"no_ignite_support": "false"}, visibility = ["//visibility:public"], ) config_setting( - name = "with_kafka_support_windows_override", - define_values = {"with_kafka_support": "true"}, - values = {"cpu": "x64_windows"}, + name = "no_kafka_support", + define_values = {"no_kafka_support": "false"}, visibility = ["//visibility:public"], ) +# Crosses between platforms and file system libraries not supported on those +# platforms due to limitations in nested select() statements. 
config_setting( name = "with_cuda_support_windows_override", define_values = {"using_cuda_nvcc": "true"}, @@ -270,48 +250,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "with_gcp_support_android_override", - define_values = {"with_gcp_support": "true"}, - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support_android_override", - define_values = {"with_hdfs_support": "true"}, - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support_android_override", - define_values = {"with_aws_support": "true"}, - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_gcp_support_ios_override", - define_values = {"with_gcp_support": "true"}, - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support_ios_override", - define_values = {"with_hdfs_support": "true"}, - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support_ios_override", - define_values = {"with_aws_support": "true"}, - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - visibility = ["//visibility:public"], -) - config_setting( name = "with_xla_support", define_values = {"with_xla_support": "true"}, @@ -340,30 +278,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "with_jemalloc_linux_x86_64_dynamic", - define_values = { - "with_jemalloc": "true", - "framework_shared_object": "true", - }, - values = { - "cpu": "k8", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_jemalloc_linux_ppc64le_dynamic", - define_values = { - "with_jemalloc": 
"true", - "framework_shared_object": "true", - }, - values = { - "cpu": "ppc", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "using_cuda_clang", define_values = { @@ -423,12 +337,20 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag specifies whether TensorFlow 2.0 API should be built instead +# of 1.* API. Note that TensorFlow 2.0 API is currently under development. +config_setting( + name = "api_version_2", + define_values = {"tf_api_version": "2"}, +) + package_group( name = "internal", packages = [ "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", + "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", "//third_party/py/tensor2tensor/...", ], @@ -541,6 +463,7 @@ tf_cc_shared_object( "$(location //tensorflow/c:version_script.lds)", ], }), + visibility = ["//visibility:public"], deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -565,6 +488,7 @@ tf_cc_shared_object( "$(location //tensorflow:tf_version_script.lds)", ], }), + visibility = ["//visibility:public"], deps = [ "//tensorflow:tf_exported_symbols.lds", "//tensorflow:tf_version_script.lds", @@ -585,10 +509,73 @@ exports_files( ], ) +genrule( + name = "install_headers", + srcs = [ + "//tensorflow/c:headers", + "//tensorflow/c/eager:headers", + "//tensorflow/cc:headers", + "//tensorflow/core:headers", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(SRCS); do + d="$${f%/*}" + d="$${d#bazel-out*genfiles/}" + d="$${d#*external/eigen_archive/}" + + if [[ $${d} == *local_config_* ]]; then + continue + fi + + if [[ $${d} == external* ]]; then + extname="$${d#*external/}" + extname="$${extname%%/*}" + if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then + continue + fi + fi + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """, + tags = ["manual"], + visibility = ["//visibility:public"], +) + +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": 
[":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + gen_api_init_files( - name = "tensorflow_python_api_gen", + name = "tf_python_api_gen_v1", srcs = ["api_template.__init__.py"], api_version = 1, + output_dir = "_api/v1/", + output_files = TENSORFLOW_API_INIT_FILES_V1, + output_package = "tensorflow._api.v1", + root_init_template = "api_template.__init__.py", +) + +gen_api_init_files( + name = "tf_python_api_gen_v2", + srcs = ["api_template.__init__.py"], + api_version = 2, + compat_api_versions = [1], + output_dir = "_api/v2/", + output_files = TENSORFLOW_API_INIT_FILES_V2, + output_package = "tensorflow._api.v2", root_init_template = "api_template.__init__.py", ) @@ -606,7 +593,10 @@ py_library( py_library( name = "tensorflow_py_no_contrib", - srcs = [":tensorflow_python_api_gen"], + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }) + [":root_init_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 779f65d5b17c350833f67f07985b00e8eb561e72..2de740e145f93b151faf5c987808dbdf73fb4fd7 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -14,15 +14,16 @@ # ============================================================================== """Bring in all of the public TensorFlow interface into this module.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os # 
pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import try: - import os # pylint: disable=g-import-not-at-top # Add `estimator` attribute to allow access to estimator APIs via # "tf.estimator..." from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top @@ -30,9 +31,8 @@ try: # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." # style imports. from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top - __path__ += [os.path.dirname(estimator_api.__file__)] + __path__ += [_os.path.dirname(estimator_api.__file__)] del estimator_api - del os except (ImportError, AttributeError): print('tf.estimator package not installed.') @@ -41,19 +41,32 @@ except (ImportError, AttributeError): from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader +# The templated code that replaces the placeholder above sometimes +# sets the __all__ variable. If it does, we have to be sure to add +# "contrib". +if '__all__' in vars(): + vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable -del absolute_import -del division -del print_function +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +if _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. So python adds these symbols for the # resolution to succeed. 
# pylint: disable=undefined-variable -del python -del core +try: + del python + del core +except NameError: + # Don't fail if these modules are not available. + # For e.g. we are using this file for compat.v1 module as well and + # 'python', 'core' directories are not under compat/v1. + pass # pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 8a9301d584775cff3ae315e6fd856b00d1734248..17e2e292eb19029d279bc12a8328edadf96f1bb8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -117,6 +117,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/c/eager:c_api", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", @@ -127,6 +128,15 @@ tf_cuda_library( ], ) +cc_library( + name = "c_api_headers", + hdrs = [ + "c_api.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], +) + exports_files( [ "version_script.lds", @@ -194,6 +204,7 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -235,6 +246,7 @@ tf_cc_test( ":c_api_experimental", ":c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b8adf6c1279e72d0c2056368253aa0cb470216e5..79811ceae57e0bddeb2a6f32bad7003e14e23422 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -1240,7 +1241,7 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; - func_name.set_name(std::string(value, value + length)); + func_name.set_name(string(value, value + length)); desc->node_builder.Attr(attr_name, func_name); } @@ -2065,7 +2066,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); + tf_results->missing_unused_key_names_data.emplace_back(id.first); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 69b3ffe2a1f620e346405607ecf742fb863aa644..d4b78138e93624a7e41e917f8210281b500661bc 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,11 +17,13 @@ limitations under the License. 
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::FunctionDef; using tensorflow::Node; @@ -79,6 +81,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a + // threadpool, so that we avoid the possibility of running the runner_ in the + // threadpool of GPU event mgr, as that can trigger more callbacks to be + // scheduled on that same threadpool, causing a deadlock in cases where the + // caller of event_mgr->ThenExecute() blocks on the completion of the callback + // (as in the case of ConstOp kernel creation on GPU, which involves copying a + // CPU tensor to GPU). + // Setting a larger thread pool does not help with the Swift caller, as we use + // a different TFE context for each thread of execution (for running graph + // functions, and their send/recvs coroutines).
+ config.set_inter_op_parallelism_threads(1); + TF_Buffer* ret = TF_NewBuffer(); TF_CHECK_OK(MessageToBuffer(config, ret)); return ret; @@ -8494,3 +8508,238 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, /*run_metadata*/ nullptr, status); VLOG(1) << "Enqueuing is done."; } + +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { + tensorflow::ServerDef server_def; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, + &server_def)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for ServerDef: ", text_proto); + return nullptr; + } + status->status = tensorflow::Status(); + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(server_def, ret)); + return ret; +} + +TFE_Context* TFE_CreateContextFromSession(TF_Session* session, + TF_Status* status) { + auto* opts = TFE_NewContextOptions(); + + // Reduce GPU memory allocation, and set appropriate config options for TFE + // context. + auto* config = + TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); + if (!status->status.ok()) { + TF_DeleteBuffer(config); // Free the config buffer on failure. + TFE_DeleteContextOptions(opts); + return nullptr; + } + + auto* ctx = TFE_NewContextFromSession(opts, session, status); + TF_DeleteBuffer(config); + TFE_DeleteContextOptions(opts); + return ctx; +} + +// TODO: retrieve the device string via TFE_ContextListDevices() +static const char DEFAULT_CPU_DEVICE[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + +static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, + int tensor_id, TF_Status* status) { + std::unique_ptr queueOp( + TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); + if (!status->status.ok()) return nullptr; // TFE_NewOp failed. + TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
+ TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); + TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); + auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); + TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), + shared_name.size()); + TFE_OpSetAttrString(queueOp.get(), "container", "", 0); + + // TODO: consider making this an unknown shape. + const int64_t* dims_ptr = nullptr; + int num_dims = 0; + TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, + /*num_values*/ 0, status); + if (!status->status.ok()) return nullptr; + + int num_retvals = 1; + TFE_TensorHandle* queue = nullptr; + TFE_Execute(queueOp.get(), &queue, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + + return queue; +} + +static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, + TFE_TensorHandle* queue, TFE_TensorHandle* tensor, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); + if (!status->status.ok()) return; + std::unique_ptr op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, tensor, status); + if (!status->status.ok()) return; + TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + + int num_retvals = 0; + TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); + if (!status->status.ok()) return; + CHECK_EQ(num_retvals, 0); +} + +static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, + TF_DataType inputType, + TFE_TensorHandle* queue, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); + if (!status->status.ok()) return nullptr; + std::unique_ptr op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + + 
TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return nullptr; + TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + TFE_TensorHandle* ret; + int num_retvals = 1; + TFE_Execute(op, &ret, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + assert(session); + VLOG(1) << "Dequeuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + + return ret; +} + +void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + assert(session); + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, 
TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status) { + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); +} + +TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + return createTFEDequeue(ctx, TF_VARIANT, queue, status); +} + +static void CheckOk(TF_Status* status) { + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); +} + +void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { + auto* status = TF_NewStatus(); + TF_Tensor* t = 
TFE_TensorHandleResolve(handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::Tensor dst; + TF_CHECK_OK(TF_TensorToTensor(t, &dst)); + LOG(INFO) << dst.DebugString(); + + TF_DeleteTensor(t); + TF_DeleteStatus(status); +} + +TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, + const char* errMsg) { + status->status = tensorflow::errors::Internal(errMsg); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6617c5a572e90e78369f73d714f39942f213040f..d98d532e32e891e21f5b7ba360c74c3256fb1947 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" // -------------------------------------------------------------------------- // Experimental C API for TensorFlow. @@ -130,6 +131,57 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, TF_Tensor* tensor, TF_Status* status); +// Create a serialized tensorflow.ServerDef proto. +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); + +// TODO: remove this API in favor of the next one. +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( + const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); + +// Creates from `session` a new eager context to run a graph function or +// sends/recvs, so that these concurrent TFE executions can share (via +// `session` and its associated device mgr) the same set of fifo queue resource +// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and +// graph function execution can access the same fifo queue resource handles +// (associated with devices managed by the device manager, which can be obtained +// from `session`). +// +// TODO: Remove this function once we migrate away from using session. 
+TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( + TF_Session* session, TF_Status* status); + +// TODO: Retire this API in favor of the next one. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( + TF_Session* session, int tensor_id, TF_DataType inputType, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, + TF_Status* status); + +// TODO: consider folding the 2 APIs below into the ones above. +TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( + TF_Session* session, int tensor_id, TF_Status* status); + +// Logs `handle` in a human-readable format (via LOG(INFO)) for debugging. +TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( + TFE_TensorHandle* handle); + +TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, + const char* errMsg); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 30fcfd401d9d634962d64aaa3bf348de91f2ecae..c6effd39697e0397278770b53e98508074f99862 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace tensorflow { namespace { @@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { TF_DeleteStatus(s); } +TEST(CAPI_EXPERIMENTAL, GetServerDefTest) { + const string expected_text_proto(R"(cluster { + job { + name: "worker" + tasks { + key: 0 + value: "tpuserver:0" + } + tasks { + key: 1 + value: "localhost:1" + } + } +} +job_name: "worker" +task_index: 1 +protocol: "grpc" +)"); + + TF_Status* status = TF_NewStatus(); + TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK); + + ServerDef actual; + ASSERT_TRUE(actual.ParseFromArray(result->data, result->length)); + string actual_text_proto; + tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto); + EXPECT_EQ(expected_text_proto, actual_text_proto); + + const string malformed_text_proto(R"(cluster { + job { + name: "worker")"); + TF_Buffer* null_result = + TFE_GetServerDef(malformed_text_proto.c_str(), status); + EXPECT_NE(TF_GetCode(status), TF_OK); + EXPECT_TRUE(tensorflow::str_util::StrContains( + TF_Message(status), "Invalid text proto for ServerDef")); + EXPECT_EQ(null_result, nullptr); + + // Cleanup + TF_DeleteBuffer(result); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a2c5a42c11361779de61b515e0f08dcc45e609b9..f68f8a3e90a971b5e4a024feaf26ba498afc48da 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/strings/base64.h" diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index aa2a537f03be31ae45ff3d6f7815b449d661cf9c..03516c39dc970aa23967107d3a0446da94669465 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -259,8 +259,8 @@ TEST(CAPI, DeprecatedSession) { TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0, nullptr, 0, run_metadata, s); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(std::string("Session was not created with a graph before Run()!"), - std::string(TF_Message(s))); + EXPECT_EQ("Session was not created with a graph before Run()!", + string(TF_Message(s))); TF_DeleteBuffer(run_metadata); TF_DeleteBuffer(run_options); @@ -1224,8 +1224,8 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), - std::string(TF_Message(s_))); + EXPECT_EQ("Operation 'add' has no attr named '_class'.", + string(TF_Message(s_))); return; } EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); @@ -1369,16 +1369,16 @@ TEST(CAPI, SavedModel) { input.flat()(i) = example.SerializeAsString(); } - const tensorflow::string input_op_name = - std::string(tensorflow::ParseTensorName(input_name).first); + const tensorflow::string input_op_name( + tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); csession.SetInputs({{input_op, 
TF_TensorFromTensor(input, s)}}); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - const tensorflow::string output_op_name = - std::string(tensorflow::ParseTensorName(output_name).first); + const tensorflow::string output_op_name( + tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 74bc25a491ac01cb725d1c004197e48727c30230..d3311f0cd06f2b151c3567735eb41b5baf72e102 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - std::string(v2_reader_->key()) /* full var's name */, + string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; + if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = std::string(v2_reader_->key()); + string key(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 37be52f57d865c1e59611540d5dab04b59e89444..3ee31a6a7ac641bbd3fc4c05568b61e433a1d523 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -68,7 +68,10 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow:internal"], + 
visibility = [ + "//learning/deepmind/courier:__pkg__", + "//tensorflow:internal", + ], deps = [ ":c_api", "//tensorflow/c:c_api", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc old mode 100644 new mode 100755 index dfb1c9a37644c726e1eabab775593596d5b556b9..3554ec0bf3202b54bfc38d67e51b89df19832302 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -244,8 +244,8 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, } void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, - unsigned char async) { - options->async = async; + unsigned char enable) { + options->async = enable; } void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { @@ -253,9 +253,9 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( } TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, - unsigned char async, + unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(async); + status->status = ctx->context.SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::IntraProcessRendezvous(device_mgr.get()); return new TFE_Context(opts->session_options.options, opts->policy, - opts->async, std::move(device_mgr), r); + opts->async, device_mgr.release(), + /*device_mgr_owned*/ true, r); +} + +TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, + TF_Session* sess, TF_Status* status) { + const tensorflow::DeviceMgr* device_mgr = nullptr; + status->status = sess->session->LocalDeviceManager(&device_mgr); + if (!status->status.ok()) return nullptr; + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, 
/*device_mgr_owned*/ false, + r); } void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } @@ -362,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { return result; } +int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } + tensorflow::int64 result = -1; // Defined even if NumElements() fails. + status->status = h->handle->NumElements(&result); + return result; +} + int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { @@ -386,6 +410,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( @@ -541,6 +578,21 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } +void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length) { + tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(data, length); + op->operation.MutableAttrs()->Set(attr_name, attr_value); +} + +void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, + TF_Status* status) { + tensorflow::Tensor t; + status->status = TF_TensorToTensor(tensor, &t); + if (status->status.ok())
op->operation.MutableAttrs()->Set(attr_name, t); +} + void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h old mode 100644 new mode 100755 index a0ebc6fa0a22ed61be91c2974352c2988fb4cd92..b2454d872207e26feb3764671474a5d87c01f84d --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -76,7 +76,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // Sets the default execution mode (sync/async). Note that this can be // overridden per thread using TFE_ContextSetAsyncForThread. TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, - unsigned char async); + unsigned char enable); TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -114,7 +114,7 @@ TFE_ContextGetDevicePlacementPolicy(TFE_Context*); // Overrides the execution mode (sync/async) for the current thread. TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, - unsigned char async, + unsigned char enable, TF_Status* status); // A tensorflow.ServerDef specifies remote workers (in addition to the current @@ -163,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, + TF_Status* status); // This function will block till the operation that produces `h` has completed. 
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, @@ -171,6 +173,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. Hence, callers should not mutate this memory (for example by @@ -305,6 +313,14 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); +TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, + const char* attr_name, + TF_Tensor* tensor, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a5c0681e2e4eddae08954d9d0178ca96a3f8f29a..104d52430cf7aa14d4d2a335a1b96e667f21ce87 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,15 +62,14 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, - bool async, - std::unique_ptr device_mgr, - tensorflow::Rendezvous* rendezvous) + TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + const tensorflow::DeviceMgr* 
device_mgr, bool device_mgr_owned, + tensorflow::Rendezvous* rendezvous) : context(opts, static_cast( default_policy), - async, std::move(device_mgr), rendezvous) {} + async, device_mgr, device_mgr_owned, rendezvous) {} tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 7126227cf529023eadf38984668a40118641bb1b..55331022b9dbd0696928fa44430f340f371432ac 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) { TFE_DeleteContext(ctx); TF_DeleteStatus(status); } + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 5607c9dcb0bbec72b2f86def3dd4e6590d73197b..008f088c2dcdd7d9114103516a4702e47a55c6de 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -99,8 +99,6 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TFE_OpAddInput(op, b, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, 
"transpose_b", 0); TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); return op; diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1adb0458c35193117b5fa5cfe9ceffbaaf699af7..5ba55a203ff70cc64c07e96b5a869a1f11c9334e 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -29,15 +29,8 @@ limitations under the License. namespace tensorflow { namespace eager { -// Information about a tensor. -struct TapeTensor { - int64 id; // Expected to be unique in the lifetime of this process. - DataType dtype; - TensorShape shape; -}; - // Represents an entry in the tape. -template +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; @@ -57,8 +50,8 @@ struct OpTapeEntry { using TensorTape = gtl::FlatMap; // Map from operation-id to tape entry. -template -using OpTape = gtl::FlatMap>; +template +using OpTape = gtl::FlatMap>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap>; // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle // specialization, which is blocked by quite a few things needing to loop back // into python now. -template +template class VSpace { public: virtual ~VSpace() {} @@ -93,10 +86,10 @@ class VSpace { gtl::ArraySlice gradient_tensors) const = 0; // Returns a tensor of the right shape and dtype filled with zeros. - virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Zeros(const TapeTensor& tensor) const = 0; // Returns a Tensor which is filled with ones and like the input. - virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Ones(const TapeTensor& tensor) const = 0; // Calls the passed-in backward function. 
virtual Status CallBackwardFunction( @@ -114,7 +107,7 @@ class VSpace { // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. -template +template class GradientTape { public: // If `persistent` is true, GradientTape will not eagerly delete backward @@ -134,10 +127,10 @@ class GradientTape { void Watch(int64 tensor_id); void RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -146,17 +139,18 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. - Status ComputeGradient(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, - gtl::ArraySlice output_gradients, - std::vector* result); + Status ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); bool IsPersistent() const { return persistent_; } private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. 
how many entries in @@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) { } } -template -bool GradientTape::ShouldRecord( +template +bool GradientTape::ShouldRecord( gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes) { CHECK_EQ(tensor_ids.size(), dtypes.size()); @@ -201,20 +195,20 @@ bool GradientTape::ShouldRecord( return false; } -template -void GradientTape::Watch(int64 tensor_id) { +template +void GradientTape::Watch( + int64 tensor_id) { tensor_tape_.emplace(tensor_id, -1); } -template -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, +template +void GradientTape::RecordOperation( + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(backward_function); return; } std::vector ids; @@ -229,16 +223,18 @@ void GradientTape::RecordOperation( for (const TapeTensor& o : output_tensors) { // Note: the tensor can have already been watched and hence be in the tape, // so we cannot check that we're inserting it here. 
- tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; tensors.push_back(o); } - op_tape_[op_id] = OpTapeEntry{ - op_type, tensors, ids, backward_function, backward_function_deleter}; + op_tape_[op_id] = OpTapeEntry{ + op_type, std::move(tensors), std::move(ids), backward_function_getter(), + backward_function_deleter}; } -template -void GradientTape::DeleteTrace(int64 tensor_id) { +template +void GradientTape::DeleteTrace( + int64 tensor_id) { auto it = tensor_usage_.find(tensor_id); if (it == tensor_usage_.end()) { return; @@ -261,7 +257,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { auto op_it = op_tape_.find(op_id); CHECK(op_it != op_tape_.end()); for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { // Found a usage for an output, so cannot delete the op. return; } @@ -304,9 +300,9 @@ void GradientTape::DeleteTrace(int64 tensor_id) { namespace { -template +template struct BackpropInitialState { - OpTape op_tape; + OpTape op_tape; // Map from tensor ID to how many references still exist for this tensor in // the tape. @@ -322,17 +318,17 @@ struct BackpropInitialState { // If `persistent_tape` is false, op_tape is cleared and backwards functions // not needed for gradient computation are deleted. Backwards functions that // are needed, are copied and returned in BackpropInitialState. 
-template -BackpropInitialState PrepareBackprop( +template +BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape* op_tape, const gtl::FlatSet& sources_set, - bool persistent_tape) { + OpTape* op_tape, + const gtl::FlatSet& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { tensor_stack.push_back(t); } - BackpropInitialState result; + BackpropInitialState result; while (!tensor_stack.empty()) { int64 tensor_id = tensor_stack.back(); tensor_stack.pop_back(); @@ -383,9 +379,9 @@ BackpropInitialState PrepareBackprop( return result; } -template +template std::vector InitialStack( - const OpTape& op_tape, + const OpTape& op_tape, const gtl::FlatMap& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { @@ -396,13 +392,13 @@ std::vector InitialStack( return result; } -template -Status InitialGradients(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice output_gradients, - const TensorTape& tensor_tape, - const OpTape& op_tape, - gtl::FlatMap>* result) { +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { @@ -416,11 +412,10 @@ Status InitialGradients(const VSpace& vspace, } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { + if (op_it->second.output_tensor_info[j].GetID() == id) { found = true; (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); + vspace.Ones(op_it->second.output_tensor_info[j])); break; } } @@ -440,6 +435,27 
@@ Status InitialGradients(const VSpace& vspace, return Status::OK(); } +// TODO(agarwal): use an automatic mechanism for handling None arguments to +// gradient functions. +// +// Some gradient functions can accept None arguments for gradients. The +// following maps the operation name to the indices at which the corresponding +// gradient function can accept None values. e.g. FusedBatchNorm outputs 5 +// values and hence receives 5 gradient values during backprop. However the +// gradient function uses only the first of those values and ignores the rest. +// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient +// corresponding to index 0 is used, and the gradient values at indices 1-4 are +// ignored (and hence can be None). The backprop algorithm can then leverage +// this by not constructing zeros to pass for those indices. +gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = new gtl::FlatMap>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + return m; +} + } // namespace // If over kMinAggregateCount gradients are accumulated and the total @@ -448,16 +464,16 @@ Status InitialGradients(const VSpace& vspace, constexpr int kMinAggregateCount = 4; constexpr int kMinAggregateBytes = 128 * 1024 * 1024; -template -Status GradientTape::ComputeGradient( - const VSpace& vspace, +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, gtl::ArraySlice target_tensor_ids, gtl::ArraySlice source_tensor_ids, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); - BackpropInitialState state = PrepareBackprop( + BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); @@ -485,10 +501,6 @@ Status 
GradientTape::ComputeGradient( VLOG(1) << " " << t; } } - gtl::FlatMap> functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; @@ -505,18 +517,16 @@ Status GradientTape::ComputeGradient( out_gradients.reserve(trace.output_tensor_info.size()); bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { - const int64 id = trace.output_tensor_info[i].id; + const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = - functions_accept_none_for_indices.find(trace.op_type); - if (func_name_it != functions_accept_none_for_indices.end() && + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { - out_gradients.push_back( - vspace.Zeros(trace.output_tensor_info[i].shape, - trace.output_tensor_info[i].dtype)); + out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i])); } } else { any_gradient_nonzero = true; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 8486b585c8587e18e8eea18a893fac0a40ff4a27..247236b760dd8c07bbb08426100b6a4d34296d2e 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); @@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { return 
result; } -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status) { +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status) { tensorflow::CppShapeInferenceResult::HandleData handle_data; if (!handle_data.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 4bcb5bde62c8a4df4e68c1ce0daaf459434ceb5d..5cce84020bc68d912d259f51512341eb5f464a2c 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); void ExtendSession(TF_Session* session, TF_Status* status); // Returns the serialized CppShapeInferenceResult::HandleData proto for -// `output` if its a resource tensor, or otherwise returns the empty string. -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // Sets `output` based on `proto`, which should be a serialized -// CppShapeInferenceResult::HandleData proto. +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. // NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string // because I couldn't get SWIG to work otherwise. 
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status); +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status); } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f56521dac0374849081fe94f16feb08e55647b56..c18b07603ae3841d3581741ab5a43f2e8b628356 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -10,11 +10,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", + "cc_library_with_android_deps", "tf_cc_binary", + "tf_cc_test", "tf_copts", "tf_gen_op_wrappers_cc", - "cc_library_with_android_deps", + "transitive_hdrs", ) cc_library( @@ -410,6 +411,7 @@ tf_cc_test( srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", @@ -452,11 +454,33 @@ tf_cc_test( ], ) +# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these tf_gen_op_wrappers_cc( - name = "cc_ops", + name = "math_ops", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + op_lib_names = [ + "math_ops", + ], + pkg = "//tensorflow/core", +) + +tf_gen_op_wrappers_cc( + name = "array_ops", api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], op_lib_names = [ "array_ops", + ], + pkg = "//tensorflow/core", +) + +tf_gen_op_wrappers_cc( + name = "cc_ops", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + deps_internal = [ + ":array_ops_internal", + ":math_ops_internal", + ], + op_lib_names = [ "audio_ops", "candidate_sampling_ops", "control_flow_ops", @@ -467,7 +491,6 @@ tf_gen_op_wrappers_cc( "logging_ops", "lookup_ops", "manip_ops", - "math_ops", "nn_ops", "no_op", "parsing_ops", @@ -479,10 +502,21 @@ tf_gen_op_wrappers_cc( "user_ops", ], other_hdrs = [ + "ops/array_ops.h", 
"ops/const_op.h", + "ops/math_ops.h", "ops/standard_ops.h", ], + other_hdrs_internal = [ + "ops/array_ops_internal.h", + "ops/math_ops_internal.h", + ], pkg = "//tensorflow/core", + deps = [ + ":array_ops", + ":const_op", + ":math_ops", + ], ) tf_cc_test( @@ -716,3 +750,26 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +transitive_hdrs( + name = "headers", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cc_ops", + ":client_session", + ":coordinator", + ":gradient_checker", + ":gradients", + ":ops", + ":queue_runner", + ":remote_fused_graph_ops", + ":scope", + "//tensorflow/cc/profiler", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:reader", + "//tensorflow/cc/saved_model:signature_constants", + "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/cc/tools:freeze_saved_model", + ], +) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index c20ea95a15e3f53b9b26716ed7b624fa853017c9..39593370d1c243e84dc5b6091724d1d404c102b0 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -466,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return std::string(name); + return string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, @@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { } } - strings::StrAppend(&class_decl, "\n"); - - if (output_types.empty()) { - strings::StrAppend(&class_decl, " Operation operation;\n"); - } + strings::StrAppend(&class_decl, "\n Operation operation;\n"); for (int i = 0; i < output_types.size(); ++i) { strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i], ";\n"); @@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const { string return_on_error = strings::StrCat("if (!", scope_str, ".ok()) return;"); + strings::StrAppend(out, " this->operation = 
Operation(ret);\n"); + // No outputs. if (graph_op_def.output_arg_size() == 0) { - strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); + strings::StrAppend(out, " return;\n"); return; } if (graph_op_def.output_arg_size() == 1) { diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index a085e1d6e2de5ad63d11eb8979ae64c26b91366f..0717e7dd4b358d6c212070374bcc3fd2f91ed0ab 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -150,7 +150,7 @@ class Input { Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); - if (t.NumElements() != v.size()) { + if (t.NumElements() != static_cast(v.size())) { status = errors::InvalidArgument( "Cannot construct a tensor with ", t.NumElements(), " from an initializer list with ", v.size(), " elements"); diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8c886f31711eb014fb9e9d600c9c78cf22073f71..6abc9e268e3ac97379954a34017ddffa010db67f 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr& graph, refiner_(refiner), scope_used_(nullptr), colocation_constraints_(), - disable_shape_inference_(false) {} + disable_shape_inference_(refiner_ == nullptr) {} Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); @@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, 
exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) exit_on_error_(true), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(kernel_label), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), 
+ assigned_device_(other.impl()->assigned_device_), colocation_constraints_( clear_colocations ? std::unordered_set() : other.impl()->GetColocationConstraints(colocate_with_op)), disable_shape_inference_(other.impl()->disable_shape_inference_) {} +Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice, + const string& assigned_device) + : graph_(other.impl()->graph_), + status_(other.impl()->status_), + name_map_(other.impl()->name_map_), + refiner_(other.impl()->refiner_), + scope_used_(other.impl()->scope_used_), + control_deps_(other.impl()->control_deps_), + name_(other.impl()->name_), + op_name_(other.impl()->op_name_), + exit_on_error_(other.impl()->exit_on_error_), + kernel_label_(other.impl()->kernel_label_), + device_(other.impl()->device_), + assigned_device_(assigned_device), + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} + std::unordered_set Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set current_constraints(colocation_constraints_); @@ -225,7 +249,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(std::string(s)); + current_constraints.emplace(s); } } } else { @@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const { if (!impl()->device_.empty()) { builder->Device(impl()->device_); } + if (!impl()->assigned_device_.empty()) { + builder->AssignedDevice(impl()->assigned_device_); + } } string Scope::Impl::GetUniqueName(const string& prefix, @@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const { return Scope(new Impl(*this, Impl::Tags::Device(), device)); } +Scope Scope::WithAssignedDevice(const string& assigned_device) const { + return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device)); +} 
+ Scope Scope::ColocateWith(const Operation& op) const { return Scope(new Impl(*this, Impl::Tags::Colocate(), op, /* clear_colocations */ false)); diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 30c32bd44b0f22d6b29dd3836d431807d0216818..e307d8989b6647dfac8d2691ed2171c86b7f3a7c 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -133,6 +133,10 @@ class Scope { /// the device field set to 'device'. Scope WithDevice(const string& device) const; + /// Returns a new scope. All ops created within the returned scope will have + /// their assigned device set to `assigned_device`. + Scope WithAssignedDevice(const string& assigned_device) const; + /// Return a new scope. All ops created within the returned scope will be /// co-located on the device where op is placed. /// NOTE: This function is intended to be use internal libraries only for diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 58adaef2e942a7fa6b0ce8d5534ac3e2fd380580..514e02e84146b6d95147d83182e5d9a07509cfa1 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -26,6 +26,8 @@ class ShapeRefiner; // graph, status, name_map, and refiner. // This is intended to enable the C API (which are used by other language // bindings) to create a Scope and access C++ functionality (i.e. gradients). +// +// Shape inference is disabled if `refiner` is nullptr. 
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner); class Scope::Impl { @@ -58,6 +60,7 @@ class Scope::Impl { enum class ExitOnError; enum class KernelLabel; enum class Colocate; + enum class AssignedDevice; }; Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, @@ -74,6 +77,7 @@ class Scope::Impl { Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label); Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations); + Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device); std::unordered_set GetColocationConstraints( const Operation& colocate_with_op) const; @@ -107,6 +111,7 @@ class Scope::Impl { const bool exit_on_error_ = false; const string kernel_label_ = ""; const string device_ = ""; + const string assigned_device_ = ""; const std::unordered_set colocation_constraints_; // If true, Scope::DoShapeInference() always returns Status:OK(). diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb196189780037f66266484962ba0385e4..2a32a2ed6f7862a29f4ce3d1aba5fdbc86adc670 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float 
alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2a958f54d50b59f0edaefb27edf0e86..f5a09e09dcda3e06c71d44d5fa5a1b121a9ade58 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = ops::internal::LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). 
+ Tensor x_init_value = test::AsTensor( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); + Tensor features = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 222e7698818204b01ad69f610bdbf5d59ffa74dd..c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -148,7 +148,7 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(main_op_name)}, nullptr /* outputs */, &run_metadata, session); } return Status::OK(); @@ -187,7 +187,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(restore_op_name)}, nullptr /* outputs */, &run_metadata, session); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 59b961cdd9dac8a1c305a3f5f520ca1b68148cca..6c29f09cde7ee17c11cb44ce48d8e9128daae4d0 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -191,13 
+192,13 @@ cc_library( srcs = ["embedded_protocol_buffers.cc"], hdrs = ["embedded_protocol_buffers.h"], deps = [ - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", "@llvm//:target", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index e77a8fecf09fa037726b0baf5d2f38aeae0ef155..b17bc658fa06b9feb7edb292bd89ef31e6309169 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" +#include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -30,8 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tfcompile { @@ -135,12 +135,12 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, indices = "[0]"; } else { for (int dim = 0; dim < shape.dimensions_size(); ++dim) { - dim_vars.push_back(strings::StrCat("size_t dim", dim)); - dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]"); - indices += strings::StrCat("[dim", dim, "]"); + dim_vars.push_back(absl::StrCat("size_t dim", dim)); + dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); + indices += absl::StrCat("[dim", dim, "]"); } } - rewrites->push_back({"{{I}}", strings::StrCat(i)}); + rewrites->push_back({"{{I}}", absl::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); @@ -194,7 +194,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, arg_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites); } @@ -235,7 +235,7 @@ Status GenResultMethods(const tf2xla::Config& config, result_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites); } @@ -304,8 +304,8 @@ std::vector BufferInfosToCppExpression( string encoded_second_as_str = encoded.second == ~0ULL ? 
"~0ULL" - : strings::StrCat(encoded.second, "ULL"); - return strings::StrCat( + : absl::StrCat(encoded.second, "ULL"); + return absl::StrCat( "::tensorflow::cpu_function_runtime::BufferInfo({", encoded.first, "ULL, ", encoded_second_as_str, "})"); }); @@ -352,13 +352,13 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { - ns_start += strings::StrCat("namespace ", n, " {\n"); + ns_start += absl::StrCat("namespace ", n, " {\n"); } ns_start += "\n"; string ns_end("\n"); for (int i = opts.namespaces.size() - 1; i >= 0; --i) { const string& n = opts.namespaces[i]; - ns_end += strings::StrCat("} // end namespace ", n, "\n"); + ns_end += absl::StrCat("} // end namespace ", n, "\n"); } // Generate metadata. @@ -568,10 +568,10 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { )"; // The replacement strategy is naive, but good enough for our purposes. 
const std::vector> rewrites = { - {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, - {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)}, + {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, - {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, + {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, @@ -590,11 +590,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, - {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, + {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, - {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, - {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, - {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, + {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, + {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)}, + {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", absl::StrJoin(buffer_infos_as_strings, ",\n")}}; absl::StrReplaceAll(rewrites, header); @@ -602,13 +602,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { } static string CreateUniqueIdentifier(const CodegenOpts& opts, - StringPiece suffix) { + absl::string_view suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { - strings::StrAppend(&result, "_", n); + absl::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_", suffix); + absl::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -678,7 +678,7 @@ Status 
ParseCppClass(const string& cpp_class, string* class_name, return Status::OK(); } -Status ValidateCppIdent(StringPiece ident, StringPiece msg) { +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { if (ident.empty()) { return errors::InvalidArgument("empty identifier: ", msg); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 83f2d3ee11d09d66f16d7ecdc11945ebe994a82a..90410c46a8e36e44454f1219ad76d0fb0937070d 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { namespace tfcompile { @@ -96,7 +96,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. -Status ValidateCppIdent(StringPiece ident, StringPiece msg); +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index e3a53edb7368c209bea16a9e34b1f452a8ff4bf8..bb288d23000527be74f01630d20bbf82e50007ce 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include #include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 1401aae7586bfd40ec209b0ae591d6ab69d0a26b..3c32d533f63f202fc9409f36709e0d29d1d7e002 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -38,11 +38,11 @@ using xla::llvm_ir::AsStringRef; static void AddEmbeddedProtocolBufferToLlvmModule( llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, - StringPiece unique_identifier, string* protobuf_array_symbol_name, + absl::string_view unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); *protobuf_array_symbol_name = - strings::StrCat(unique_identifier, "_protobuf_array_contents"); + absl::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); llvm::Constant* protobuf_array_initializer = @@ -55,9 +55,9 @@ static void AddEmbeddedProtocolBufferToLlvmModule( protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); } -static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, - StringPiece protobuf_array_symbol_name, - int64 protobuf_array_size) { +static string CreateCPPShimExpression( + absl::string_view qualified_cpp_protobuf_name, + absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) { string code = "[]() {\n" " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n" @@ -68,9 
+68,9 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, return absl::StrReplaceAll( code, { - {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, - {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, - {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, + {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)}, + {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)}, + {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)}, }); } @@ -93,7 +93,7 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, } static StatusOr> -GetTargetMachineFromTriple(StringPiece target_triple) { +GetTargetMachineFromTriple(absl::string_view target_triple) { std::string error; std::string normalized_triple = llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); @@ -110,8 +110,8 @@ GetTargetMachineFromTriple(StringPiece target_triple) { } StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed) { + absl::string_view target_triple, + absl::Span protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); @@ -135,8 +135,8 @@ StatusOr CreateEmbeddedProtocolBuffers( protobuf_to_embed.qualified_cpp_protobuf_name, protobuf_array_symbol_name, protobuf_array_size); - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); + cpp_variable_decl = + absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];"); } else { cpp_shim = "nullptr"; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 4e194a6aba9a9efcad27c47c42e148d8e537ae68..cf5c04ac4bdff73b76a365c346f7db60ce2d8197 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,8 +20,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -83,8 +83,8 @@ struct ProtobufToEmbed { // is stored in the object_file_data field in the returned // EmbeddedProtocolBuffers instance. StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed); + absl::string_view target_triple, + absl::Span protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 7364d63b53a83a44bd99ed190b07a26073a484ce..10fa33ab5e84dcbc1629bee6214e8969046f19c2 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -25,6 +25,7 @@ test_suite( ":test_graph_tfmatmul_test", ":test_graph_tfmatmulandadd_test", ":test_graph_tfsplits_test", + ":test_graph_tftop_k_test", ":tfcompile_test", ], ) @@ -42,6 +43,7 @@ py_binary( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:session", "//tensorflow/python:training", @@ -66,8 +68,14 @@ genrule( "test_graph_tfmatmul.pb", "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", + "test_graph_tftop_k.pb", ], - cmd = "$(location :make_test_graphs) --out_dir $(@D)", + # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any + # GPUs which might be present. This is important because builds may run + # concurrently with tests, and tests need to be able to assume that they + # have control of the full GPU. 
+ cmd = "CUDA_VISIBLE_DEVICES='' " + + "$(location :make_test_graphs) --out_dir $(@D)", tags = ["manual"], tools = [":make_test_graphs"], ) @@ -187,6 +195,9 @@ tf_library( cpp_class = "MatMulAndAddCompWithProfiling", enable_xla_hlo_profiling = True, graph = "test_graph_tfmatmulandadd.pb", + tags = [ + "manual", + ], ) tf_library( @@ -200,6 +211,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tftop_k", + testonly = 1, + config = "test_graph_tftop_k.config.pbtxt", + cpp_class = "TopKComp", + graph = "test_graph_tftop_k.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -218,11 +240,13 @@ tf_cc_test( ":test_graph_tfmatmulandadd", ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", + ":test_graph_tftop_k", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 9ec7df163b1425f917e9ec51559efad3e6f05e75..64b861a73091642b03573543a5c55618bf33915d 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app from tensorflow.python.training import saver as saver_lib @@ -46,7 +47,7 @@ def tfadd(_): def tfadd_with_ckpt(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.Variable(constant_op.constant([0]), 
name='y_saved') + y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.initialize_all_variables() @@ -61,7 +62,7 @@ def tfadd_with_ckpt(out_dir): def tfadd_with_ckpt_saver(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.Variable(constant_op.constant([0]), name='y_saved') + y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.initialize_all_variables() @@ -142,6 +143,12 @@ def tfsplits(_): array_ops.identity(y, name='result') +def tftop_k(_): + x = array_ops.placeholder(dtypes.int32, shape=[5], name='x') + output = nn_ops.top_k(x, 2, name='values') + array_ops.identity(output[1], name='indices') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -163,6 +170,7 @@ def main(_): write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) + write_graph(tftop_k, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b4ac2d7cbb517be841932b1cfae9e28decdf8d3 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt @@ -0,0 +1,13 @@ +# Text form of tensorflow.tf2xla.Config proto. 
+feed { + id { node_name: "x" } + shape { + dim { size: 5 } + } +} +fetch { + id { node_name: "values" } +} +fetch { + id { node_name: "indices" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index dd2b151098f2054571ac32b8b506cbc00659588a..f10852c7850f61bfd8b99fa9f1648202d182085e 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -29,10 +29,12 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -447,6 +449,30 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, TopK) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + TopKComp fn; + + fn.set_thread_pool(&device); + // x = [4, 1, 4, 4, 3] + fn.arg0(0) = 4; + fn.arg0(1) = 1; + fn.arg0(2) = 4; + fn.arg0(3) = 4; + fn.arg0(4) = 3; + + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); + const int32 expected_values[] = {4, 4}; + const int32 expected_indices[] = {0, 2}; + EXPECT_EQ(expected_values[0], fn.result0(0)); + EXPECT_EQ(expected_values[1], fn.result0(1)); + EXPECT_EQ(expected_indices[0], fn.result1(0)); + EXPECT_EQ(expected_indices[1], fn.result1(1)); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. 
@@ -543,7 +569,13 @@ TEST(TFCompileTest, HloProfiling) { string hlo_profile_as_string = xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), /*clock_rate_ghz=*/1.0); - VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; + + // Strip away identifier details from the profile string to avoid this test + // being a change detector for xla internals. Identifiers such as '%dot.0.7' + // just become '%dot'. + RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1"); + VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string; std::vector hlo_profile_lines = absl::StrSplit(hlo_profile_as_string, '\n'); @@ -551,16 +583,14 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto add_profile_line = HasSubstr( - "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto tuple_profile_line = HasSubstr( - "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); - auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, " + "f32[2,2]{1,0} %add)"); + auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/aot/tfcompile.bzl 
b/tensorflow/compiler/aot/tfcompile.bzl index 326f73b975aec3a7a6bc7cdc9a92f540ad545ad6..859c84bb91657422b830255b0217f8946d351458 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -105,12 +105,18 @@ def tf_library( freeze_file = freeze_name + ".pb" # First run tfcompile to generate the list of out_nodes. + # + # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we + # launch from using any GPUs which might be present. This is important + # because builds may run concurrently with tests, and tests need to be + # able to assume that they have control of the full GPU. out_nodes_file = "out_nodes_" + freeze_name native.genrule( name = ("gen_" + out_nodes_file), srcs = [config], outs = [out_nodes_file], - cmd = ("$(location " + tfcompile_tool + ")" + + cmd = ("CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), tools = [tfcompile_tool], @@ -142,9 +148,12 @@ def tf_library( out_nodes_file, ] + freeze_saver_srcs, outs = [freeze_file], - cmd = ("$(location " + - "//tensorflow/python/tools:freeze_graph)" + - freeze_args), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args + ), tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) @@ -177,16 +186,19 @@ def tf_library( metadata_object_file, function_object_file, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " 
--graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, @@ -216,14 +228,17 @@ def tf_library( outs = [ session_module_pb, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, @@ -258,6 +273,7 @@ def tf_library( "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index f3c44e9dda8ce96a268420a7f4d0f22e50ddfe41..b95b063348c5cdfdcaed635ba527e9f0bfd6092d 100644 --- 
a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -92,8 +92,9 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, - StringPiece(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR( + WriteStringToFile(env, flags.out_function_object, + absl::string_view(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index df81f3c23e38a2ec2cea827cd0adb123855e7714..ced0cd03f74d147451ca2bf54108dc7517b50acd 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") # Target that bundles up the XLA CPU and GPU JIT devices. 
cc_library( @@ -50,7 +51,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -62,7 +63,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda([ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), @@ -76,7 +77,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -94,7 +95,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep @@ -111,7 +112,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep @@ -257,6 +258,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -265,6 +267,7 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", + 
"//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, @@ -279,7 +282,7 @@ cc_library( deps = [ ":common", ":compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -321,6 +324,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -340,7 +344,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -358,18 +362,20 @@ tf_cc_test( cc_library( name = "compilation_passes", srcs = [ - "build_xla_launch_ops_pass.cc", + "build_xla_ops_pass.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", + "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", ], hdrs = [ - "build_xla_launch_ops_pass.h", + "build_xla_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", + "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -379,12 +385,17 @@ cc_library( ":shape_inference_helpers", ":union_find", ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", + 
"//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -395,6 +406,10 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -410,6 +425,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -453,7 +469,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -464,6 +480,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -471,22 +488,30 @@ tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ + "build_xla_ops_pass_test.cc", "encapsulate_subgraphs_pass_test.cc", + "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":node_matchers", ":xla_cluster_util", + ":xla_cpu_device", + ":xla_gpu_device", 
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -495,6 +520,9 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -512,7 +540,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -566,6 +594,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "@com_google_absl//absl/strings", ], ) @@ -586,6 +615,46 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "node_matchers", + testonly = True, + srcs = ["node_matchers.cc"], + hdrs = ["node_matchers.h"], + deps = [ + "//tensorflow/cc:ops", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + 
+tf_cc_test( + name = "node_matchers_test", + srcs = ["node_matchers_test.cc"], + deps = [ + ":node_matchers", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:ops", + "//tensorflow/core:ops", + "//tensorflow/core:test_main", + ], +) + +tf_custom_op_py_library( + name = "xla_ops_py", + kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], + visibility = [ + ":friends", + ], + deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc deleted file mode 100644 index b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/public/version.h" - -namespace tensorflow { - -static Status BuildLaunchNode( - const string& nodename, const string& function_name, - const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, int num_resources, - const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, - Graph* graph, Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("XlaLaunch"); - def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Nresources", num_resources, &def); - AddNodeAttr("Tresults", result_dtypes, &def); - NameAttrList function; - function.set_name(function_name); - *function.mutable_attr() = function_attr; - AddNodeAttr("function", function, &def); - - Status status; - *node = graph->AddNode(def, &status); - return status; -} - -static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { - VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; - - int num_constant_args, num_resource_args; - 
TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); - - if (num_constant_args < 0 || num_resource_args < 0 || - num_constant_args + num_resource_args > node->num_inputs()) { - return errors::InvalidArgument( - "Invalid number of constant/resource arguments to XLA kernel."); - } - const int num_nonconst_args = - node->num_inputs() - num_constant_args - num_resource_args; - - DataTypeVector const_dtypes(node->input_types().begin(), - node->input_types().begin() + num_constant_args); - DataTypeVector arg_dtypes( - node->input_types().begin() + num_constant_args, - node->input_types().begin() + num_constant_args + num_nonconst_args); - - // Build a XlaLaunch operator to execute the function body. - Node* launch_node; - TF_RETURN_IF_ERROR(BuildLaunchNode( - graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, - node->output_types(), graph, &launch_node)); - launch_node->set_assigned_device_name(node->assigned_device_name()); - - // Copy incoming edges to the launch node. - for (const Edge* edge : node->in_edges()) { - if (edge->IsControlEdge()) { - graph->AddControlEdge(edge->src(), launch_node); - } else { - graph->AddEdge(edge->src(), edge->src_output(), launch_node, - edge->dst_input()); - } - } - - // Copy outgoing edges to the launch node. 
- std::vector out_edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : out_edges) { - Node* dst = edge->dst(); - int src_output = edge->src_output(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (edge->IsControlEdge()) { - graph->AddControlEdge(launch_node, dst); - } else { - graph->AddEdge(launch_node, src_output, dst, dst_input); - } - } - graph->RemoveNode(node); - - return Status::OK(); -} - -Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - - for (Node* n : graph->op_nodes()) { - // In all cases, only try to compile computational nodes. - if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { - continue; - } - - // Only compile nodes that are marked for compilation by the - // compilation-marking pass (via 'attr_name'). - if (IsXlaCompiledKernel(*n)) { - TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); - } - } - - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph, - options.flib_def); - } - return Status::OK(); -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..054f31ba3352b2215e6b0448c8ec8a70cb98b8e5 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -0,0 +1,338 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { +void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { + std::vector out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + // TODO(sanjoy): This does not update NodeDef inputs. To be able to update + // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up + // the NodeDef inputs to the function call nodes. 
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); + g->RemoveEdge(edge); + } +} + +// Returns a data value that is dead iff `control` is dead. +Output ControlToData(const Scope& scope, Node* control) { + Output data = ops::Const(scope.WithOpName("ctrl_as_data"), + Tensor(DT_BOOL, TensorShape({0}))); + scope.graph()->AddControlEdge(control, data.node()); + return Output(data.node()); +} + +// Returns an operation that can be control-depended on that is dead iff `data` +// is dead. +Operation DataToControl(const Scope& scope, Output data) { + return Operation( + ops::Identity(scope.WithOpName("data_as_ctrl"), data).node()); +} + +// Replaces each outgoing edge from `old_node` with a merge node that merges in +// the corresponding output from `new_node`. +void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) { + if (!s.status().ok()) { + return; + } + + std::vector merged_outputs(old_node->num_outputs(), Output(nullptr)); + + std::vector data_edges; + absl::c_copy_if(old_node->out_edges(), std::back_inserter(data_edges), + [](const Edge* e) { return !e->IsControlEdge(); }); + + for (const Edge* e : data_edges) { + int oidx = e->src_output(); + Output merged_output = merged_outputs[oidx]; + if (merged_output.node() == nullptr) { + ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)), + {Output(old_node, oidx), Output(new_node, oidx)}); + merged_output = merged_outputs[oidx] = merge_op.output; + } + + Node* dst = e->dst(); + int dst_idx = e->dst_input(); + + s.graph()->RemoveEdge(e); + s.graph()->AddEdge(merged_output.node(), merged_output.index(), dst, + dst_idx); + } +} + +// Replaces each control successor of `old_node` to execute whenever either +// `old_node` or `new_node` is executed. 
+void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) { + if (!s.status().ok()) { + return; + } + + std::vector ctrl_edges; + absl::c_copy_if(old_node->out_edges(), std::back_inserter(ctrl_edges), + [](const Edge* e) { return e->IsControlEdge(); }); + + if (ctrl_edges.empty()) { + return; + } + + // We can't merge control edges directly so we instead first "convert" them to + // normal values that can be merged, merge the values and then "convert" the + // merged value back into control. + // + // NB! We need to copy out the outgoing control edges before constructing + // old_ctrl_as_data otherwise the control edge from old_node to the constant + // in ControlToData will be present in ctrl_edges. + + Output old_ctrl_as_data = ControlToData(s, old_node); + Output new_ctrl_as_data = ControlToData(s, new_node); + + ops::Merge ctrl_merge_as_data(s.WithOpName("ctrl_merge"), + {old_ctrl_as_data, new_ctrl_as_data}); + Operation ctrl_merge = DataToControl(s, ctrl_merge_as_data.output); + + for (const Edge* e : ctrl_edges) { + s.graph()->AddControlEdge(ctrl_merge.node(), e->dst()); + s.graph()->RemoveControlEdge(e); + } +} + +struct XlaClusterInfo { + std::vector constant_inputs; + std::vector non_constant_inputs; + std::vector resource_inputs; + NameAttrList function; +}; + +Output IncomingEdgeAsOutput(const Edge* e) { + return Output(e->src(), e->src_output()); +} + +Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { + int num_constant_inputs, num_resource_inputs; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs)); + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs)); + + if (num_constant_inputs < 0 || num_resource_inputs < 0 || + num_constant_inputs + num_resource_inputs > n->num_inputs()) { + return errors::InvalidArgument( + "Invalid number of constant/resource arguments to XLA kernel."); + } + + int num_non_constant_inputs = + n->num_inputs() - 
num_constant_inputs - num_resource_inputs; + + std::vector input_edges_vector; + TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector)); + absl::Span input_edges(input_edges_vector); + + absl::c_transform(input_edges.subspan(0, num_constant_inputs), + std::back_inserter(result->constant_inputs), + IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs, num_non_constant_inputs), + std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs + num_non_constant_inputs, + num_resource_inputs), + std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput); + + result->function.set_name(n->type_string()); + *result->function.mutable_attr() = n->def().attr(); + return Status::OK(); +} + +Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { + for (const Edge* e : from->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), to); + } + } + + return Status::OK(); +} + +void RemoveAllIncomingControlEdges(Graph* g, Node* n) { + std::vector incoming_ctrl_edges; + absl::c_copy_if(n->in_edges(), std::back_inserter(incoming_ctrl_edges), + [](const Edge* e) { return e->IsControlEdge(); }); + for (const Edge* e : incoming_ctrl_edges) { + g->RemoveControlEdge(e); + } +} + +// Returns true (into `result`) if `node` must be compiled. 
+Status NodeRequiresCompilation(Node* n, bool* result) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + const XlaOpRegistry::DeviceRegistration* registration = nullptr; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + return errors::Internal("Could not find compilation device ", + device_type.type()); + } + *result = registration->requires_compilation; + return Status::OK(); +} + +Status ReplaceNodeWithXlaCompileAndXlaRun( + const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, + Graph* g, Node* n) { + bool requires_compilation; + TF_RETURN_IF_ERROR(NodeRequiresCompilation(n, &requires_compilation)); + if (!lazy_compilation_enabled) { + requires_compilation = true; + } + + Status status; + Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) + .NewSubScope(n->name()) + .WithDevice(n->requested_device()) + .WithAssignedDevice(n->assigned_device_name()); + + XlaClusterInfo cluster_info; + TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); + + ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), + /*constants=*/cluster_info.constant_inputs, + /*args=*/cluster_info.non_constant_inputs, + /*resources=*/cluster_info.resource_inputs, + /*must_compile=*/requires_compilation, + cluster_info.function); + TF_RETURN_IF_ERROR( + CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); + + if (requires_compilation) { + // "Strict" compilation: every _XlaCompile invocation must compile the + // cluster. 
+ std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, + std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + xla_compile.key, n->output_types()); + + MoveOutgoingEdges(g, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + g->RemoveNode(n); + } else { + // "Lazy" compilation: an _XlaCompile invocation may decide not to compile + // the cluster based on profitability heuristics. + + // We generate the following graph: + // + // (use_tf_call, use_xla_run) = + // Switch(pred=xla_compile.compilation_successful, + // value=xla_compile.key) + // + // tf_call_outputs = cluster_N(..., ^use_tf_call) + // xla_run_outputs = _XlaRun(..., key=use_xla_run) + // outputs = Merge(tf_call_outputs, xla_run_outputs). + ops::Switch s(root.WithOpName("predicated_compilation_key"), + xla_compile.key, xla_compile.compilation_successful); + Output predicated_compilation_key = s.output_true; + Output inverse_predicated_compilation_key = s.output_false; + + std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, + std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + predicated_compilation_key, n->output_types()); + + MergeOutgoingControlEdges(root, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + + MergeOutgoingDataEdges(root, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + + TF_RETURN_IF_ERROR(root.status()); + + // We already have a TensorFlow function call into the cluster -- the + // original node we set out to rewrite. We just wire in the correct control + // deps and we're done. 
+ RemoveAllIncomingControlEdges(g, n); + g->AddControlEdge( + DataToControl(root, inverse_predicated_compilation_key).node(), n); + n->ClearAttr(kXlaCompiledKernelAttr); + } + + return Status::OK(); +} +} // namespace + +Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + + // Copy out the nodes we want to rewrite to avoid modifying the graph while we + // iterate on graph->op_nodes(). + std::vector xla_compiled_kernels; + absl::c_copy_if(graph->op_nodes(), std::back_inserter(xla_compiled_kernels), + [](const Node* n) { + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + return false; + } + + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + return IsXlaCompiledKernel(*n); + }); + + bool lazy_compilation_enabled = enable_lazy_compilation_ + ? *enable_lazy_compilation_ + : legacy_flags::GetBuildXlaOpsPassFlags() + .tf_xla_enable_lazy_compilation; + + for (Node* n : xla_compiled_kernels) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( + *options.flib_def, lazy_compilation_enabled, graph, n)); + } + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); + } + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..58f7c4b3a0d1472f602e8234f9f08c23dfe78a34 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and +// executes (using XLA) TF function calls marked with "_XlaCompiledKernel". +class BuildXlaOpsPass : public GraphOptimizationPass { + public: + // If enable_lazy_compilation is not nullopt then *enable_lazy_compilation + // overrides --tf_xla_enable_lazy_compilation flag in deciding whether lazy + // compilation is enabled. + explicit BuildXlaOpsPass( + absl::optional enable_lazy_compilation = absl::nullopt) + : enable_lazy_compilation_(enable_lazy_compilation) {} + + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + absl::optional enable_lazy_compilation_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..11df946cc186660242574c2644463a26ead44f1f --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -0,0 +1,234 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +class BuildXlaOpsTest : public ::testing::Test { + protected: + void SetUp() override { + // This is needed to register the XLA_* devices. 
+ CHECK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices_) + .ok()); + } + + void TearDown() override { + for (Device* device : devices_) { + delete device; + } + } + + private: + std::vector devices_; +}; + +using ::tensorflow::testing::FindNodeByName; +using ::tensorflow::testing::matchers::Attr; +using ::tensorflow::testing::matchers::CtrlDeps; +using ::tensorflow::testing::matchers::Inputs; +using ::tensorflow::testing::matchers::NodeWith; +using ::tensorflow::testing::matchers::Op; +using ::tensorflow::testing::matchers::Out; +using ::testing::_; + +Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { + auto graph = absl::make_unique(OpRegistry::Global()); + TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); + + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : graph->nodes()) { + if (n->requested_device().empty()) { + n->set_assigned_device_name(kCpuDevice); + } else { + n->set_assigned_device_name(n->requested_device()); + } + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = &graph; + BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); + TF_RETURN_IF_ERROR(pass.Run(opt_options)); + VLOG(3) << graph->ToGraphDefDebug().DebugString(); + *result = std::move(graph); + return Status::OK(); +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, int num_constant_args, + int num_resource_args, Node** result) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node); + Status s; + *result = graph->AddNode(call_node, &s); + return s; +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, Node** result) 
{ + return MakeXlaCompiledKernel(graph, callee_name, node_name, + /*num_constant_args=*/0, /*num_resource_args=*/0, + result); +} + +Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) { + Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT, + TensorShape({})); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee_" + id), + var_handle, value_to_write); + return assign_op.operation.node(); +} + +Node* MakeWrite(const Scope& scope, const string& id) { + return MakeWrite( + scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { + const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->set_requested_device(kXlaDeviceName); + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun"))))); +} + +TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary 
flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK( + MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call)); + + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr graph; + Status failure_status = BuildXlaOps(root, &graph); + ASSERT_FALSE(failure_status.ok()); + EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); +} + +TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + TF_ASSERT_OK(root.DoShapeInference(call)); + + Node* write_op = MakeWrite(root, Output(call), "write_result"); + + auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false)); + auto predicated_compilation_key = + NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile))); + auto xla_run = + NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key))); + auto tf_call = + NodeWith(Op("cluster_0"), + CtrlDeps(NodeWith(Op("Identity"), + Inputs(Out(0, predicated_compilation_key))))); + auto merge = NodeWith(Op("Merge"), Inputs(Out(tf_call), Out(xla_run))); + auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge))); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, assign_var); +} + +TEST_F(BuildXlaOpsTest, OnXlaDevice) { + const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); + + FunctionDefLibrary flib_def = + 
CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->set_requested_device(kXlaDeviceName); + TF_ASSERT_OK(root.DoShapeInference(call)); + + Node* write_op = MakeWrite(root, Output(call), "write_result"); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + auto xla_op = + NodeWith(Op("_XlaRun"), Inputs(Out(NodeWith(Op("_XlaCompile"))))); + auto assign_var = + NodeWith(Op("AssignVariableOp"), Inputs(Out(NodeWith()), Out(xla_op))); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, assign_var); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index a7f8a5613c019b355759a53e8de304eddafb3257..6f1ff85f24a4c1fd3e6d54fcff9f8868aee6f750 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -209,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, // device memory. // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory + // in device memory except for resources. 
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (int i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } // Create the kernel. NameAttrList function; diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index fe28502f69d34e7c075bdf85afd2473024b4081d..b7ae7fbeb3912882368dc828e8d6fcd50735b04e 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,11 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW @@ -108,7 +111,7 @@ class Predicate { virtual string ToString() const = 0; int64 hash() const { return hash_; } - virtual gtl::ArraySlice GetOperands() const = 0; + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; virtual ~Predicate() {} @@ -129,7 +132,7 @@ class Predicate { }; int64 HashPredicateSequence(Predicate::Kind kind, - gtl::ArraySlice preds) { + absl::Span preds) { int64 hash = ::tensorflow::hash()(kind); for (Predicate* pred : preds) { hash = Hash64Combine(hash, pred->hash()); @@ -154,13 +157,15 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); + return 
absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } - gtl::ArraySlice GetOperands() const override { return operands_; } - gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -183,12 +188,14 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } - gtl::ArraySlice GetOperands() const override { return operands_; } - gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -202,12 +209,14 @@ class NotPredicate : public Predicate { operands_({operand}) {} string ToString() const override { - return strings::StrCat("~", operand()->ToString()); + return absl::StrCat("~", operand()->ToString()); } Kind kind() const override { return Kind::kNot; } Predicate* operand() const { return operands_[0]; } - gtl::ArraySlice GetOperands() const override { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } private: std::array operands_; @@ -234,13 +243,15 @@ class AndRecurrencePredicate : public Predicate { Predicate* step() const { return operands_[1]; } string ToString() const override { - return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); } Kind kind() const override { return Kind::kAndRecurrence; } - gtl::ArraySlice GetOperands() const override { return operands_; } + absl::Span GetOperands() const 
override { + return operands_; + } private: std::array operands_; @@ -259,12 +270,12 @@ class SymbolPredicate : public Predicate { must_be_true_(must_be_true) {} string ToString() const override { - return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + return must_be_true() ? absl::StrCat("*", tensor_id_.ToString()) : tensor_id_.ToString(); } Kind kind() const override { return Kind::kSymbol; } - gtl::ArraySlice GetOperands() const override { return {}; } + absl::Span GetOperands() const override { return {}; } // If `must_be_true()` is true this SymbolPredicate represents the proposition // "tensor_id() is live and evaluates to true". @@ -288,7 +299,7 @@ class SymbolPredicate : public Predicate { template /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { - gtl::FlatSet visited; + absl::flat_hash_set visited; std::vector stack; stack.push_back(p); @@ -313,11 +324,11 @@ template // them. class PredicateFactory { public: - Predicate* MakeAndPredicate(gtl::ArraySlice operands) { + Predicate* MakeAndPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/true); } - Predicate* MakeOrPredicate(gtl::ArraySlice operands) { + Predicate* MakeOrPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/false); } @@ -374,7 +385,9 @@ class PredicateFactory { new PredicateT(std::forward(args)...)); } - Predicate* MakeAndOrImpl(gtl::ArraySlice operands, bool is_and); + Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); + Predicate* MakeInternedAndOr(std::vector simplified_ops, + Predicate::Kind pred_kind); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. This makes checking @@ -387,7 +400,7 @@ class PredicateFactory { // for the owning pointers to predicate instances. 
using SignatureForAndOr = - std::pair>; + std::pair>; using SignatureForNot = Predicate*; using SignatureForAndRec = std::pair; using SignatureForSymbol = std::pair; @@ -409,24 +422,53 @@ class PredicateFactory { } }; - gtl::FlatMap, - HashSignatureForAndOr> + absl::flat_hash_map, + HashSignatureForAndOr> interned_and_or_instances_; - gtl::FlatMap> + absl::flat_hash_map> interned_not_instances_; - gtl::FlatMap> + absl::flat_hash_map> interned_and_rec_instances_; - gtl::FlatMap, - HashSignatureForSymbol> + absl::flat_hash_map, + HashSignatureForSymbol> interned_symbol_instances_; }; +Predicate* PredicateFactory::MakeInternedAndOr( + std::vector simplified_ops, Predicate::Kind pred_kind) { + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it != interned_and_or_instances_.end()) { + return it->second.get(); + } + + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + absl::Span operands_slice = simplified_ops; + std::unique_ptr new_pred = + pred_kind == Predicate::Kind::kAnd + ? Make(std::move(simplified_ops)) + : Make(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + interned_and_or_instances_.emplace( + SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred)); + return new_pred_ptr; +} + // Common code to create AndPredicate or OrPredicate instances. -Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, - bool is_and) { +Predicate* PredicateFactory::MakeAndOrImpl( + absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; - gtl::FlatSet simplified_ops_set; + Predicate::Kind other_pred_kind = + is_and ? 
Predicate::Kind::kOr : Predicate::Kind::kAnd; + absl::flat_hash_set simplified_ops_set; std::vector simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. @@ -451,7 +493,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, } // Simplify "A&~A=>False" and "A|~A=>True". - gtl::FlatSet negated_ops; + absl::flat_hash_set negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast(*op).operand()); @@ -464,30 +506,63 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, } } - std::stable_sort( - simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + // If all ops contain the same subop, then factor it out thanks to the + // distributive property. Such as: + // - (A & B) | (A & C) | (A & D) => A & (B | C | D) + // - (A | B) & (A | C) & (A | D) => A | (B & C & D) + // + // First find any predicates contained in all subops. + std::vector common_inner_operands; + absl::flat_hash_set common_inner_operands_set; + for (Predicate* op : simplified_ops) { + if (op->kind() != other_pred_kind) { + common_inner_operands.clear(); + break; + } - auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); - if (it == interned_and_or_instances_.end()) { - simplified_ops.shrink_to_fit(); - // NB! Because we'll use a non-owning reference to simplified_ops in the - // key for interned_and_or_instances_ we need to be careful to std::move() - // it all the way through. - gtl::ArraySlice operands_slice = simplified_ops; - std::unique_ptr new_pred = - is_and ? 
Make(std::move(simplified_ops)) - : Make(std::move(simplified_ops)); + if (common_inner_operands.empty()) { + common_inner_operands.insert(common_inner_operands.end(), + op->GetOperands().begin(), + op->GetOperands().end()); + } else { + std::vector sub_ops_intersection; + common_inner_operands.clear(); + absl::c_copy_if(op->GetOperands(), + std::back_inserter(common_inner_operands), + [&](Predicate* sub_op) { + return common_inner_operands_set.count(sub_op) == 1; + }); + } + if (common_inner_operands.empty()) break; + common_inner_operands_set.clear(); + common_inner_operands_set.insert(common_inner_operands.begin(), + common_inner_operands.end()); + } - Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_or_instances_ - .emplace(SignatureForAndOr(pred_kind, operands_slice), - std::move(new_pred)) - .second); - return new_pred_ptr; - } else { - return it->second.get(); + if (common_inner_operands.empty()) { + return MakeInternedAndOr(std::move(simplified_ops), pred_kind); + } + + // For all predicates that can be factored out, remove them and recreate the + // subops. 
+ std::vector factored_ops; + for (Predicate* op : simplified_ops) { + std::vector new_sub_op_ops; + absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops), + [&](Predicate* sub_op) { + return std::find(common_inner_operands.begin(), + common_inner_operands.end(), + sub_op) == common_inner_operands.end(); + }); + factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and)); } + + Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and); + std::vector outer_ops; + outer_ops.push_back(new_inner_op); + outer_ops.insert(outer_ops.end(), common_inner_operands.begin(), + common_inner_operands.end()); + return MakeAndOrImpl(outer_ops, !is_and); } class DeadnessAnalysisImpl : public DeadnessAnalysis { @@ -496,15 +571,17 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} Status Populate(); - Status PopulateWithReversePostOrder(gtl::ArraySlice rpo); + Status PopulateWithReversePostOrder(absl::Span rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; - gtl::FlatMap PredicateMapAsString() const; + absl::flat_hash_map PredicateMapAsString() + const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; - std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); + Status GetInputPreds(Node* n, EdgeKind edge_kind, + std::vector* result); // Sets the predicate for output `output_idx` of `n` to `pred`. 
Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate @@ -527,7 +604,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { } } - void SetPredicate(Node* n, gtl::ArraySlice output_idxs, Predicate* pred, + void SetPredicate(Node* n, absl::Span output_idxs, Predicate* pred, std::vector* should_revisit) { for (int output_idx : output_idxs) { SetPredicate(n, output_idx, pred, should_revisit); @@ -541,7 +618,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleNode(Node* n, std::vector* should_revisit); const Graph& graph_; - gtl::FlatMap predicate_map_; + absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; bool vlog_; }; @@ -550,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) { return TensorId(e->src()->name(), e->src_output()); } -std::vector DeadnessAnalysisImpl::GetIncomingPreds( - Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { - std::vector incoming_preds; +Status DeadnessAnalysisImpl::GetInputPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind, + std::vector* result) { + result->clear(); for (const Edge* in_edge : n->in_edges()) { bool should_process = edge_kind == EdgeKind::kDataAndControl || @@ -561,17 +639,27 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()) << n->name(); - incoming_preds.push_back(it->second); + if (it == predicate_map_.end()) { + GraphCycles graph_cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + + // If we didn't return with an error above then the graph is probably + // fine and we have a bug in deadness analysis. + return errors::Internal("Could not find input ", in_edge->DebugString(), + " to ", n->name(), + " when visiting the graph in post-order. 
Most " + "likely indicates a bug in deadness analysis."); + } + result->push_back(it->second); } } - return incoming_preds; + return Status::OK(); } Status DeadnessAnalysisImpl::HandleSwitch(Node* n, std::vector* should_revisit) { - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( @@ -600,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, } namespace { -const Edge* FindUniqueBackedge(Node* merge) { +Status CreateMultipleNextIterationInputsError(Node* merge) { + std::vector backedges; + for (const Edge* backedge : merge->in_edges()) { + if (backedge->src()->IsNextIteration()) { + backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src()))); + } + } + return errors::InvalidArgument( + "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge), + ": \n", absl::StrJoin(backedges, "\n"), + "\nMerge nodes can have at most one incoming NextIteration edge."); +} + +Status FindUniqueBackedge(Node* merge, const Edge** result) { + *result = nullptr; CHECK(merge->IsMerge()); - const Edge* result = nullptr; for (const Edge* e : merge->in_edges()) { if (e->src()->IsNextIteration()) { - CHECK_EQ(result, nullptr) - << "Multiple backedges to " << merge->DebugString(); - result = e; + if (*result != nullptr) { + return CreateMultipleNextIterationInputsError(merge); + } + *result = e; } } - return result; + return Status::OK(); } // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step @@ -625,7 +727,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, } std::vector and_ops; - gtl::ArraySlice recurrent_pred_ops = + absl::Span recurrent_pred_ops = backedge_predicate->GetOperands(); bool found_sym = false; @@ -689,9 +791,12 @@ Status 
DeadnessAnalysisImpl::HandleMerge(Node* n, return Status::OK(); } + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); + // We're visiting this merge for the first time and it is a acyclic merge. - Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + Predicate* input_data_pred = + predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -702,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // of an unvisited backedge. Try to pattern match the predicate expression // for that backedge (which should be visited now) into an and recurrence // for the merge node. - if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + const Edge* unique_backedge; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge)); + if (unique_backedge) { if (Predicate* step = DeduceStepPredicate( &predicate_factory_, it->second, predicate_map_[InputEdgeToTensorId(unique_backedge)])) { @@ -733,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); SetPredicate(n, {0, Graph::kControlSlot}, @@ -746,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, Status DeadnessAnalysisImpl::HandleGeneric(Node* n, std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. 
- Predicate* pred = predicate_factory_.MakeAndPredicate( - GetIncomingPreds(n, EdgeKind::kDataAndControl)); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); + Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { SetPredicate(n, output_idx, pred, should_revisit); } @@ -784,7 +892,7 @@ Status DeadnessAnalysisImpl::Populate() { } Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( - gtl::ArraySlice rpo) { + absl::Span rpo) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // @@ -904,9 +1012,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } -gtl::FlatMap +absl::flat_hash_map DeadnessAnalysisImpl::PredicateMapAsString() const { - gtl::FlatMap result; + absl::flat_hash_map result; std::vector tensor_ids; for (const auto& kv_pair : predicate_map_) { CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); @@ -924,7 +1032,7 @@ Status ComputePredicates(const Graph& graph, } Status ComputePredicates(const Graph& graph, - gtl::ArraySlice reverse_post_order, + absl::Span reverse_post_order, PredicateMapTy* out_predicate_map) { DeadnessAnalysisImpl impl(&graph); TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 401d6e406ab3db81d0cbd69b480d5962dab1f357..354782374ad070a3d19ddd68bfb986d5a8285e51 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -16,15 +16,15 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. -using PredicateMapTy = gtl::FlatMap; +using PredicateMapTy = absl::flat_hash_map; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // Returns a map describing the predicate each Tensor was mapped to. For @@ -32,7 +32,7 @@ Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // specified in `reverse_post_order` which must be a valid RPO for the graph // minus NextIteration->Merge edges. Status ComputePredicates(const Graph& graph, - gtl::ArraySlice reverse_post_order, + absl::Span reverse_post_order, PredicateMapTy* out_predicate_map); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 28a56044d5e3795fc3ecf5d1092491b87cb90f01..617e31488c7daeb714c0ff7056b786e4eaf7873f 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) { EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); } -TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { - // This demonstrates one of the weaknesses in the current approach -- since we - // only do some basic simplifications we can't see that "(A|B)&C" == - // "(A&C)|(B&C)". 
+TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) { + // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "A"); + ops::Switch sw_1 = CreateSwitch(root, "B"); + Output add0 = + ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true); + Output add1 = + ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false); + ops::Merge or2(root.WithOpName("or2"), {add0, add1}); + Output add3 = + ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false); + ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true}); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true"); +} + +TEST(DeadnessAnalysisTest, AndOrDistributive) { + // (A|B)&C == (A&C)|(B&C) Scope root = Scope::NewRootScope().ExitOnError(); ops::Switch sw_0 = CreateSwitch(root, "0"); @@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node())); } TEST(DeadnessAnalysisTest, Ternary) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 2788102620546d8eab657c519f078c5b03e265cc..da030b3bcc7aacae2306bec30f4b8927aa042d7c 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,9 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -42,10 +45,8 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -58,10 +59,27 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +void SortControlInputs(GraphDef* gdef) { + int64 num_nodes = gdef->node_size(); + for (int64 i = 0; i < num_nodes; ++i) { + NodeDef* node = gdef->mutable_node(i); + // Stable sort control inputs and leave the order of data inputs unchanged. + std::stable_sort(node->mutable_input()->begin(), + node->mutable_input()->end(), + [](const string& a, const string& b) { + bool a_is_control = absl::StartsWith(a, "^"); + bool b_is_control = absl::StartsWith(b, "^"); + return (!a_is_control && b_is_control) || + (a_is_control && b_is_control && a < b); + }); + } +} + namespace { bool AreAllParentsGuaranteedConst( - const Node& n, const gtl::FlatSet& runtime_const_nodes) { + const Node& n, + const absl::flat_hash_set& runtime_const_nodes) { if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. 
@@ -84,7 +102,7 @@ bool AreAllParentsGuaranteedConst( void MarkGuaranteedConstants( const Graph& graph, const std::vector>& src_arg_pairs) { - gtl::FlatSet guaranteed_const_nodes; + absl::flat_hash_set guaranteed_const_nodes; std::vector srcs; srcs.reserve(src_arg_pairs.size()); for (const auto& src_arg : src_arg_pairs) { @@ -731,6 +749,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { graph_->set_versions(graph_in->versions()); } + // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is + // determined. In case of hard placement, ensure all the encapsulated nodes + // have the same requested device, which in turn will be the requested device + // for the entire encapsulated subgraph. In case of soft placement, use a + // deterministic approach to fill in the requested device. Handle co-location + // constraints similarly if they exist. if (device_.empty()) { device_ = node->assigned_device_name().empty() ? node->requested_device() @@ -755,7 +779,7 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -790,7 +814,7 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -950,16 +974,15 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", - oc_subgraph_name, "_host_compute"), + NodeDefBuilder 
builder(absl::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), kHostComputeOp); builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); builder.Attr("ancestors", host_compute_ancestors); - builder.Attr("key", - strings::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -1017,8 +1040,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; - NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), - "NoOp"); + NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); Status s = builder.Finalize(&seq_def); @@ -1087,14 +1109,17 @@ Status Encapsulator::Subgraph::BuildFunctionDef( function_def_name_ = name; FunctionDef fdef; + // Verify that the graph has well-formed control flow structure. 
+ std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy)); TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); + dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), + *graph_, library); + dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), + fdef); } if (!reuse_existing_functions || library->Find(name) == nullptr) { @@ -1130,8 +1155,8 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shapes", shapes); } else { string inference_graph_name = - strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, - "_", outside_compilation_subgraph_name); + absl::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); FunctionDef fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); @@ -1155,10 +1180,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; dump_graph::DumpGraphToFile( - strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, library); dump_graph::DumpFunctionDefToFile( - strings::StrCat("replace_encapsulate_fdef_", name), fdef); + absl::StrCat("replace_encapsulate_fdef_", name), fdef); } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); @@ -1186,8 +1211,7 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - strings::StrCat(call_node_def_.name(), "_key_placeholder"), - "Placeholder"); + absl::StrCat(call_node_def_.name(), 
"_key_placeholder"), "Placeholder"); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1221,16 +1245,16 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_recv"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); builder.Device(device_); builder.Attr("Toutputs", dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); builder.Attr(group_attribute, subgraph_name); builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); @@ -1276,13 +1300,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_send"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), kSendFromHostOp); builder.Device(device_); builder.Attr("Tinputs", dtypes); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. 
builder.Attr("device_ordinal", 0); @@ -1343,28 +1367,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { - Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no group_attribute. - attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - } - bool has_group_attr = s.ok(); - s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, - outside_compilation_attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no outside_compilation attribute. - outside_compilation_attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - if (!has_group_attr) { - return errors::InvalidArgument( - "Node ", node->name(), " has ", outside_compilation_attribute_, - " attribute but no ", group_attribute_, " attribute."); + AttrSlice attrs = node->attrs(); + attr->clear(); + outside_compilation_attr->clear(); + bool found_group_attribute = false; + bool found_outside_compilation_attribute = false; + for (const auto& node_attr : attrs) { + if (node_attr.first == group_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *attr = node_attr.second.s(); + found_group_attribute = true; + } else if (node_attr.first == outside_compilation_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *outside_compilation_attr = node_attr.second.s(); + found_outside_compilation_attribute = true; } + if (found_group_attribute && found_outside_compilation_attribute) break; + } + + if (found_outside_compilation_attribute && !found_group_attribute) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } else { + return Status::OK(); } - return Status::OK(); } bool IsInSubgraph(const string& func_id, 
const string& outside_compilation_id) { @@ -1507,16 +1534,13 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); - // Verify that the graph has well-formed control flow structure. - std::vector dummy; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } if (VLOG_IS_ON(1)) { // Dump subgraphs. for (auto& entry : subgraphs_) { dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } } @@ -2052,7 +2076,7 @@ struct PathDetails { struct SubgraphAndClusterHash { inline std::size_t operator()(const SubgraphAndCluster& v) const { return hash()( - strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + absl::StrCat(v.subgraph, v.outside_compilation_cluster)); } }; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 926589546fec72048485d30966f31b24e44b1245..90354a801afb26b003e00c4529069fdc61bbca32 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Sorts each node's control inputs by their names. This guarantees that for two +// structurally equivalent GraphDefs, we get the same traversal ordering on +// node's control input fields. +// TODO(hpucha): Move the utilities to a more appropriate place. 
+void SortControlInputs(GraphDef* gdef); + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index b3600fc48b9daa0e901e2b01cdc121aef0a1e8af..49958093b8dcf35e8adcdfd2f7dfce8558d5db6f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "absl/strings/match.h" @@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, FunctionDef* fdef = library->add_function(); TF_RETURN_IF_ERROR(GraphToFunctionDef( *graph, - strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + absl::StrCat("_outside_compilation_shape_inference_", name_suffix), fdef)); return Status::OK(); } @@ -65,18 +66,18 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = b.find(elt_a.first); if (iter == b.end()) { if (diff) { - *diff = strings::StrCat( - map_name, " expected: contains element with key '", - key_to_string(elt_a.first), "' got: map has no such element"); + *diff = absl::StrCat(map_name, " expected: contains element with key '", + key_to_string(elt_a.first), + "' got: map has no such element"); } return false; } if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { - *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), "' has value '", - value_to_string(elt_a.second), "' got: '", - value_to_string(iter->second), "'"); + *diff = absl::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), "' has value '", + value_to_string(elt_a.second), "' got: '", + 
value_to_string(iter->second), "'"); } return false; } @@ -85,9 +86,9 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = a.find(elt_b.first); if (iter == a.end()) { if (diff) { - *diff = strings::StrCat(map_name, " got: contains element with key '", - key_to_string(elt_b.first), - "' expected: map has no such element"); + *diff = absl::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); } return false; } @@ -99,25 +100,25 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, const string& diff_preamble, string* diff) { if (a.op() != b.op()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected op '", a.op(), "' got '", b.op()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); } return false; } if (a.device() != b.device()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected device '", a.device(), "' got '", - b.device()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); } return false; } if (a.input_size() != b.input_size()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected ", a.input_size(), " inputs got ", - b.input_size(), " expected:\n", a.DebugString(), - "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } @@ -127,10 +128,10 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, if (absl::StartsWith(a.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) { if (diff) { - *diff = strings::StrCat( - diff_preamble, " mismatch for node ", 
a.name(), " input ", i, - ", expected control input ", a.input(i), " got ", b.input(i), - " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected control input ", + a.input(i), " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -138,19 +139,19 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, control_input_b.insert(b.input(i)); } else if (a.input(i) != b.input(i)) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " input ", i, ", expected ", a.input(i), - " got ", b.input(i), " expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), " got ", + b.input(i), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } } if (control_input_a != control_input_b) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " control inputs differ expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -170,18 +171,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return av.DebugString() == bv.DebugString(); } }, - strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), - diff); + absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Signature mismatch for function ", - a.signature().name(), ", expected:\n", - a.signature().DebugString(), "\ngot:\n", - 
b.signature().DebugString()); + *diff = + absl::StrCat("Signature mismatch for function ", a.signature().name(), + ", expected:\n", a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } @@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, - strings::StrCat("attr mismatch for function ", a.signature().name()), + absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } @@ -201,7 +201,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const string& av, const string& bv) { return av == bv; }, - strings::StrCat("ret mismatch for function ", a.signature().name()), + absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; } @@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (a.node_def(i).name() == b.node_def(j).name()) { if (!EqualFunctionNodeDef( a.node_def(i), b.node_def(j), - strings::StrCat("Function ", a.signature().name()), diff)) { + absl::StrCat("Function ", a.signature().name()), diff)) { return false; } found = true; @@ -220,9 +220,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", expected: has node '", a.node_def(i).name(), - "' got: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); } return false; } @@ -237,9 +237,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", got: has node '", b.node_def(i).name(), - "' expected: no node of that name"); + *diff = absl::StrCat("Function ", 
a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); } return false; } @@ -258,8 +258,8 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, auto it = actual_index.find(expected_function.signature().name()); if (it == actual_index.end()) { if (diff) { - *diff = strings::StrCat("Did not find expected function '", - expected_function.signature().name(), "'"); + *diff = absl::StrCat("Did not find expected function '", + expected_function.signature().name(), "'"); } return false; } @@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, if (!actual_index.empty()) { if (diff != nullptr) { - *diff = strings::StrCat("Found unexpected function '", - actual_index.begin()->second->signature().name(), - "'"); + *diff = + absl::StrCat("Found unexpected function '", + actual_index.begin()->second->signature().name(), "'"); } return false; } @@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, +Node* KnownShapeBase(DataType dtype, absl::Span shape, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", @@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, .FinalizeBuilder(&node_builder); } -Node* KnownShape(const gtl::ArraySlice& shape, +Node* KnownShape(absl::Span shape, const GraphDefBuilder::Options& opts) { return KnownShapeBase(DT_FLOAT, shape, opts); } @@ -417,14 +417,12 @@ Node* KeyPlaceholder(const string& call_node, } Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, - const string& oc_cluster, - const gtl::ArraySlice& dtypes, + const string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", 
cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_recv"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -441,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_send"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -683,8 +680,8 @@ std::vector> GraphEdges(const Graph& graph) { for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( - strings::StrCat(edge->src()->name(), ":", edge->src_output()), - strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); + absl::StrCat(edge->src()->name(), ":", edge->src_output()), + absl::StrCat(edge->dst()->name(), ":", edge->dst_input())); } std::sort(edges.begin(), edges.end()); return edges; @@ -892,13 +889,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + 
{"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, @@ -1038,26 +1035,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1190,13 +1187,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", 
absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1213,13 +1210,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1364,13 +1361,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1386,13 +1383,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + 
{"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F2_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"i_0_retval", "I:o:0"}}); @@ -1495,13 +1492,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1579,13 +1576,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1661,12 +1658,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", 
absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1742,12 +1739,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1846,13 +1843,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -1955,13 +1952,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -2066,37 +2063,37 @@ TEST(EncapsulateSubgraphsTest, 
OutsideCompilationClusterDependency) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {"outside_compilation_O1_host_compute", "outside_compilation_O2_host_compute"}}}, @@ -2272,13 +2269,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", 
{"c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..2ce6fa73fc448ca83fa392aa909cb385453eb8b6 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -0,0 +1,362 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { + +const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = + "_xla_compile_id"; + +namespace { + +const char* const kXlaClusterOutput = "XlaClusterOutput"; + +// Checks if a graph node is marked to be a guaranteed constant. +bool is_guaranteed_constant(const Node& n) { + bool guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) + .ok()) { + return false; + } + return guaranteed_constant; +} + +// Finds the `index` of an _Arg or _Retval node. +Status GetIndexAttr(const Node& n, int num_args, int* index) { + TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); + if (*index < 0 || *index >= num_args) { + return errors::InvalidArgument("Invalid ", n.type_string(), " number ", + *index); + } + return Status::OK(); +} + +// Returns the data type of the destination of an edge. +DataType EdgeType(const Edge* edge) { + return edge->dst()->input_type(edge->dst_input()); +} + +// Adds the control inputs of `node` to `*deps`. +void AddControlInputs(const Node& node, absl::flat_hash_set* deps) { + for (const Edge* edge : node.in_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->src()); + } + } +} + +// Adds the control outputs of `node` to `*deps`. 
+void AddControlOutputs(const Node& node, absl::flat_hash_set* deps) { + for (const Edge* edge : node.out_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->dst()); + } + } +} + +// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts +// the arguments into the order expected by XlaLaunch computations: +// 1) arguments +// 2) resource variable arguments +// See the documentation of EncapsulateSubgraphsInFunctions for the meaning +// of the arguments. +// +// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. +Status RewriteSubgraph(const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + const int num_args = input_permutation->size(); + const int num_retvals = output_permutation->size(); + + std::vector args; + std::vector retvals; + args.reserve(num_args); + retvals.reserve(num_retvals); + for (Node* n : graph->nodes()) { + if (n->type_string() == "_Arg") { + // Check if this is a guaranteed constant. + if (is_guaranteed_constant(*n)) { + return errors::InvalidArgument( + "Guaranteed constants are not supported (", n->name(), ")"); + } + args.push_back(n); + } else if (n->type_string() == "_Retval") { + retvals.push_back(n); + } + } + + if (std::find(args.begin(), args.end(), nullptr) != args.end()) { + return errors::InvalidArgument("Missing or non-consecutive arguments"); + } + + // Reorders the arguments. + std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { + // Non-resources appear before resources + bool a_is_resource = (a->output_type(0) == DT_RESOURCE); + bool b_is_resource = (b->output_type(0) == DT_RESOURCE); + // Uses the name as a tiebreaker so the output is deterministic. 
+ StringPiece a_name(a->name()); + StringPiece b_name(b->name()); + return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); + }); + + // Sorts the retvals by name so the order is deterministic. + std::sort(retvals.begin(), retvals.end(), + [](Node* a, Node* b) { return a->name() < b->name(); }); + + // Computes the permutation to produce the correct argument order, and update + // the argument indices. + int variable_start_index = num_args; + for (int i = 0; i < num_args; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); + if (args[i]->output_type(0) == DT_RESOURCE && + variable_start_index == num_args) { + variable_start_index = i; + } + (*input_permutation)[index] = i; + args[i]->AddAttr("index", i); + } + VLOG(4) << "variable_start_index: " << variable_start_index; + + // Computes the permutation to produce the correct retval order, and update + // the retval indices. + for (int i = 0; i < num_retvals; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); + (*output_permutation)[index] = i; + retvals[i]->AddAttr("index", i); + } + + AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), + call_def); + AddNodeAttr("_variable_start_index", variable_start_index, call_def); + + // Uniquify the function name. + GraphDef gdef; + graph->ToGraphDef(&gdef); + + // Before serialization, sort each node's control inputs to achieve + // determinism. Sorting control inputs could help (but not necessarily) create + // a deterministic serialization and fingerprint. Other sources of + // nondeterminism include unstable node ordering. + SortControlInputs(&gdef); + // Fingerprint the function. + // Nondeterminism in serialization would not lead to incorrect results, but + // may cause spurious cache misses. SerializeToStringDeterministic is a + // best-effort deterministic serialization. 
+ string serialized; + TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); + uint64 fingerprint = Fingerprint64(serialized); + LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); + return Status::OK(); +} + +} // namespace + +/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Check for undeclared outputs before Encapsulation, so we can give a better + // error message. + // TODO(phawkins): merge this with the encapsulation code to avoid the extra + // O(n) pass over the edges. + for (const Edge* e : (*graph)->edges()) { + if (!e->IsControlEdge() && + e->src()->attrs().Find(kXlaClusterAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->dst()->type_string() != kXlaClusterOutput) { + return errors::InvalidArgument( + "Undeclared output of XLA computation. A common cause of this error " + "is variable initializers that depend on the XLA computation. Edge: ", + e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", + e->dst_input()); + } + } + + auto output = absl::make_unique((*graph)->op_registry()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, "", **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), + "EncapsulateXlaComputationsPass failed"); + graph->swap(output); + return Status::OK(); +} + +/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( + Graph* graph) { + // Finds all of the XlaLaunch function calls, to avoid mutating the graph + // while iterating. + std::vector launch_nodes; + for (Node* n : graph->nodes()) { + string name; + if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + launch_nodes.push_back(n); + } + } + + // Replaces each launch function call together with its neighboring + // XlaClusterOutput nodes with a XlaLaunch node. 
+ for (Node* launch : launch_nodes) { + int variable_start_index; + TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", + &variable_start_index)); + + std::vector in_edges; + TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); + + const int num_inputs = in_edges.size(); + const int num_variables = num_inputs - variable_start_index; + const int num_args = variable_start_index; + + VLOG(4) << "Launch node '" << launch->name() << "'" + << " input edges: " << in_edges.size() << " num_args: " << num_args + << " num_variables: " << num_variables; + + std::vector nodes_to_remove = {launch}; + + // Data and control inputs to the new XlaLaunch node. + std::vector> data_inputs(num_inputs); + absl::flat_hash_set control_inputs; + DataTypeVector arg_types(num_args); + + AddControlInputs(*launch, &control_inputs); + + for (int i = 0; i < num_args; ++i) { + const Edge* edge = in_edges[i]; + data_inputs[i] = {edge->src(), edge->src_output()}; + arg_types[i] = EdgeType(edge); + } + + // Appends the variable inputs. + for (int i = 0; i < num_variables; ++i) { + int pos = variable_start_index + i; + const Edge* edge = in_edges[pos]; + data_inputs[pos] = {edge->src(), edge->src_output()}; + } + + // Outputs. 
+ const int num_outputs = launch->output_types().size(); + absl::flat_hash_set control_outputs; + std::vector>> data_outputs(num_outputs); + DataTypeVector output_types(num_outputs); + + for (const Edge* le : launch->out_edges()) { + if (le->IsControlEdge()) { + control_outputs.insert(le->dst()); + } else { + TF_RET_CHECK(le->src_output() < num_outputs); + Node* output_node = le->dst(); + + TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) + << le->DebugString(); + nodes_to_remove.push_back(output_node); + + for (const Edge* oe : output_node->out_edges()) { + TF_RET_CHECK(!oe->IsControlEdge()); + data_outputs[le->src_output()].push_back( + {oe->dst(), oe->dst_input()}); + } + output_types[le->src_output()] = output_node->input_type(0); + + AddControlOutputs(*output_node, &control_outputs); + } + } + + NodeDef def; + def.set_name(launch->name()); + + // Target the XLA CPU/GPU backends. + VLOG(2) << "Replacing with XlaLaunch"; + VLOG(2) << "Device is " << launch->requested_device(); + def.set_op("XlaLaunch"); + def.set_device(launch->requested_device()); + AddNodeAttr("Tconstants", DataTypeVector{}, &def); + AddNodeAttr("Targs", arg_types, &def); + AddNodeAttr("Nresources", num_variables, &def); + AddNodeAttr("Tresults", output_types, &def); + NameAttrList function; + function.set_name(launch->type_string()); + AddNodeAttr("function", function, &def); + + for (Node* node : nodes_to_remove) { + VLOG(2) << "Deleting node " << node->DebugString(); + // Ensure that we do not attempt to add control edges to nodes that are + // deleted. 
+ control_inputs.erase(node); + control_outputs.erase(node); + graph->RemoveNode(node); + } + + Status status; + Node* xla_launch = graph->AddNode(def, &status); + if (!status.ok()) { + return status; + } + for (int i = 0; i < data_inputs.size(); ++i) { + graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, + i); + } + for (Node* n : control_inputs) { + graph->AddControlEdge(n, xla_launch); + } + for (int i = 0; i < data_outputs.size(); ++i) { + for (const auto& successor : data_outputs[i]) { + graph->AddEdge(xla_launch, i, successor.first, successor.second); + } + } + for (Node* n : control_outputs) { + graph->AddControlEdge(xla_launch, n); + } + } + return Status::OK(); +} + +Status EncapsulateXlaComputationsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateXlaComputations(): " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + VLOG(1) << "EncapsulateXlaComputations() half-way: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + VLOG(1) << "EncapsulateXlaComputations() finished: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..99e9dfd598f29697dd009aa32f5317ed3dc647ae --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up an XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + + namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + static const char* const kXlaClusterAttr; // _xla_compile_id + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. 
During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. + static Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. + static Status BuildXlaLaunchOps(Graph* graph); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..192e1c7b32467d80cef6ff61a1c7078f8dea9dfb --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -0,0 +1,350 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static std::unique_ptr MakeOuterGraph( + const FunctionLibraryDefinition& flib_def, const string& function) { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") + 
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); + + Status status; + Node* launch = scope.graph()->AddNode(def, &status); + TF_CHECK_OK(status); + TF_CHECK_OK(scope.DoShapeInference(launch)); + scope.graph()->AddEdge(a.node(), 0, launch, 0); + scope.graph()->AddEdge(b.node(), 0, launch, 1); + scope.graph()->AddEdge(c.node(), 0, launch, 2); + scope.graph()->AddEdge(d.node(), 0, launch, 3); + scope.graph()->AddEdge(u.node(), 0, launch, 4); + scope.graph()->AddEdge(v.node(), 0, launch, 5); + scope.graph()->AddEdge(w.node(), 0, launch, 6); + + auto out0 = + ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); + auto out1 = + ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); + auto out2 = + ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); + auto out3 = + ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +// Makes an encapsulate body graph for use in tests. 
+static std::unique_ptr MakeBodyGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); + auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); + + auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); + auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); + auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); + add_attrs(b_identity.node()); + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, arg3); + add_attrs(g.node()); + + auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), + b_identity, 0); + auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); + auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); + auto out3 = + ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { + // Test that control edge insertion order doesn't affect the cache key + // 
(cluster name) generated by TPU encapsulate pass. + auto get_serialized_graph = [](bool control_input_reversed, + bool operand_reversed) -> string { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); + auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); + + ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) + : ops::Add(scope.WithOpName("E"), a1, a0); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, + "launch0"); + }; + add_attrs(e.node()); + + TF_CHECK_OK(scope.ToGraph(graph.get())); + auto get_node_in_graph = [&graph](Node* node) { + return graph->FindNodeId(node->id()); + }; + // Insert control edge in different order. The order should not affect + // the encapsulated or serialized graph. + if (!control_input_reversed) { + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + } else { + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + } + } + TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + GraphDef gdef; + graph->ToGraphDef(&gdef); + // Before serialization, sort control inputs first to remove + // nondeterminism. + SortControlInputs(&gdef); + string serialized; + SerializeToStringDeterministic(gdef, &serialized); + return serialized; + }; + + // Changing the order of control input shouldn't affect the graph generated. 
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, + /*operand_reversed=*/false), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); + + // Changing the order of data input should affect the graph generated. + EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/true), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); +} + +TEST(EncapsulateXlaComputations, Encapsulate) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); + add_attrs(b_identity.node()); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), a, c); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, d); + add_attrs(g.node()); + + auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), 
b_identity); + auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e); + auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); + auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + std::unique_ptr graph_copy(new Graph(&flib_def)); + CopyGraph(*graph, graph_copy.get()); + + TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + + std::unordered_map index = graph->BuildNodeNameIndex(); + string function = index.at("launch0")->type_string(); + + // Tests the outer graph is as expected. + { + std::unique_ptr outer = MakeOuterGraph(flib_def, function); + GraphDef expected_def; + outer->ToGraphDef(&expected_def); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); + } + + // Tests the encapsulated body graph is as expected. + { + std::unique_ptr body = MakeBodyGraph(); + GraphDef expected_body_def; + body->ToGraphDef(&expected_body_def); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, + DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); + } + + // Encapsulates the same computation again, verifies we reuse the same + // function. 
Encapsulation should be deterministic to avoid recompilation. + TF_ASSERT_OK( + EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); + std::unordered_map index_copy = + graph_copy->BuildNodeNameIndex(); + string function_copy = index_copy.at("launch0")->type_string(); + EXPECT_EQ(function, function_copy); +} + +TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { + std::unique_ptr body_graph = MakeBodyGraph(); + FunctionDefLibrary flib; + TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph = MakeOuterGraph(flib_def, "launch0"); + TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); + + Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NameAttrList function; + function.set_name("launch0"); + auto launch = ops::XlaLaunch( + scope.WithOpName("launch0").WithDevice("/gpu:0"), + std::initializer_list{}, std::initializer_list{a, b, c, d}, + std::initializer_list{u, v, w}, + DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); + + auto consumer0_a = + ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); + auto consumer0_b = + ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); + auto consumer0_c = + ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); + auto consumer1 = + ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); + auto 
consumer2 = + ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); + auto consumer3 = + ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); + + GraphDef expected_def; + TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ(expected_def, actual_def); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 676f71a75aede2a7720ae0c8a579d64cc184509a..8212956adfeca263334e3d0d954f69e13c1ba28d 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -14,6 +14,7 @@ cc_library( hdrs = ["graphcycles.h"], deps = [ "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 805bbc62c1e2e877de87ab8faf3d60b829743df8..756377bd9502d7172b29f317c471963d1dee09a9 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include #include -#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -44,7 +44,7 @@ namespace { typedef std::unordered_set NodeSet; template struct VecStruct { - typedef gtl::InlinedVector type; + typedef absl::InlinedVector type; }; template using Vec = typename VecStruct::type; diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index c37b6112cc8a92047d495d057f59e2281710e678..085c0e5adbb270e71ff3447a936555c99904e26c 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -13,14 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { +// PRE_PLACEMENT passes: + +// EncapsulateXlaComputationsPass rewrites computations generated by the +// xla.compile() Python code into XlaLaunch nodes. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, + EncapsulateXlaComputationsPass); + +// from +// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +// FunctionalizeControlFlowPass: 27 +// +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) 
into functional +// control flow structure (XlaIf/XlaWhile). Following passes must +// handle those FunctionDef correctly. + +// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); @@ -36,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, // Must run after EncapsulateSubgraphsPass. REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, - BuildXlaLaunchOpsPass); + BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 253a5d254792a19d98b75310ea6848f42597c0c7..26cb3af9d69ba1877c67853cde28d2477d394efc 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -7,9 +7,9 @@ package( ) cc_library( - name = "xla_launch_op", - srcs = ["xla_launch_op.cc"], - hdrs = ["xla_launch_op.h"], + name = "xla_ops", + srcs = ["xla_ops.cc"], + hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", @@ -26,6 +26,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc deleted file mode 100644 index fde4135bf7f5f7bdede170d47fb2a76d1d6b3ae9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ /dev/null @@ -1,287 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" - -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/variable_ops.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function) - : OpKernel(ctx), - constants_(constants), - resources_(resources), - device_type_(ctx->device_type()), - 
function_(function) { - if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id_ = se::host::kHostPlatformId; - } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = ctx->device() - ->tensorflow_gpu_device_info() - ->stream->parent() - ->platform() - ->id(); - } else { - platform_id_ = nullptr; - } -} - -Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); - if (s.ok()) { - *cache = new XlaCompilationCache(metadata->client(), - metadata->jit_device_type()); - return Status::OK(); - } - - auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_); - if (!platform.ok()) { - return platform.status(); - } - xla::LocalClientOptions client_options; - client_options.set_platform(platform.ValueOrDie()); - client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); - auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); - if (!client.ok()) { - return client.status(); - } - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(), - ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - device_type_.type()); - } - *cache = new XlaCompilationCache( - client.ValueOrDie(), DeviceType(registration->compilation_device_name)); - return Status::OK(); -} - -void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOpBase::Compute " - << Canonicalize(function_.name(), AttrSlice(&function_.attr())); - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. - ResourceMgr* rm = ctx->resource_manager(); - OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - - se::Stream* stream = - ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; - - XlaCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [this, ctx](XlaCompilationCache** cache) { - return BuildCompilationCache(ctx, cache); - })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); - - const XlaDevice::Metadata* metadata = nullptr; - Status s = XlaDevice::GetMetadata(ctx, &metadata); - bool allocate_xla_tensors = s.ok(); - bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams(); - - // Get the platform_id_ for XLA_* devices. - if (platform_id_ == nullptr) { - if (s.ok()) { - platform_id_ = metadata->platform()->id(); - } - } - - std::map variables = - SnapshotResourceVariables(ctx, resources_); - - xla::LocalClient* client = static_cast(cache->client()); - - XlaAllocator local_xla_allocator(client->backend().platform(), - ctx->device()->GetAllocator({})); - xla::DeviceMemoryAllocator* xla_allocator; - // If we are on an XlaDevice, use the underlying XLA platform's allocator - // directly. We could use the StreamExecutor's allocator which may - // theoretically be more correct, but XLA returns a nice OOM message in a - // Status and StreamExecutor does not. - // - // Importantly we can't use ctx->device()->GetAllocator() as the allocator - // (which local_xla_allocator above uses) as on an XlaDevice, this is a - // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a - // real allocator to allocate real buffers. 
- if (allocate_xla_tensors) { - xla_allocator = client->backend().memory_allocator(); - } else { - xla_allocator = &local_xla_allocator; - } - - XlaCompiler::Options options; - options.client = client; - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); - } - options.device_type = cache->device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); - options.device_allocator = xla_allocator; - if (metadata) { - options.shape_representation_fn = metadata->shape_representation_fn(); - } - - const XlaCompiler::CompilationResult* kernel; - xla::LocalExecutable* executable; - - std::map constant_args; - for (int i : constants_) { - constant_args.insert({i, ctx->input(i)}); - } - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - // If we resolve constants we never emit them on the device, meaning that if - // they are needed by a following computation the host has to transfer - // them. Not resolving constants is expected to be faster than resolving - // constants. - compile_options.resolve_compile_time_constants = true; - // Optimization: where possible, have the computation return a naked array - // rather than a one-element tuple. - compile_options.always_return_tuple = false; - - OP_REQUIRES_OK( - ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, compile_options)); - - VLOG(1) << "Executing XLA Computation..."; - - XlaComputationLaunchContext launch_context( - client, xla_allocator, allocate_xla_tensors, use_multiple_streams); - launch_context.PopulateInputs(ctx, kernel, variables); - - // Execute the computation. 
- VLOG(2) << "Executing computation."; - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(xla_allocator); - run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(GetXLARandomSeed()); - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - - auto run_result = executable->Run(launch_context.arguments(), run_options); - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; - - OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( - ctx, kernel, run_result.ConsumeValueOrDie())); - VLOG(1) << "Done"; -} - -namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - -// Helper static functions to construct parameters for -// XlaLocalLaunchBase constructor from OpKernelConstruction. 
-std::vector ConstantsVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - std::vector constants(constant_types.size()); - std::iota(constants.begin(), constants.end(), 0); - return constants; -} - -std::vector ResourcesVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - - DataTypeVector arg_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Targs", &arg_types)); - - int num_resources; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Nresources", &num_resources)); - - std::vector resources(num_resources); - std::iota(resources.begin(), resources.end(), - constant_types.size() + arg_types.size()); - return resources; -} - -NameAttrList FunctionAttr(OpKernelConstruction* ctx) { - const NameAttrList* func; - OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); - return *func; -} - -#undef OP_REQUIRES_OK_RETURN -} // namespace - -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), - FunctionAttr(ctx)) {} - -XlaLocalLaunchOp::~XlaLocalLaunchOp() { - VLOG(1) << "XlaLocalLaunchOp destroyed"; -} - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch") - .Device(DEVICE_GPU) - .HostMemory("constants") - .HostMemory("resources"), - XlaLocalLaunchOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h deleted file mode 100644 index bf1e99066897b185471129130cbefaa505e5f8b2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ - -#include "tensorflow/compiler/jit/xla_compilation_cache.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. -// The only difference is that it does not require arguments to follow -// the "constants, then regular args, then resources" order. -// It takes vectors of constant and resource arguments explicitly. -// It does not have corresponding OpDef because it is never present -// in the GraphDef. -// Currently, it is used by eager runtime. FunctionLibraryRuntime creates -// this kernel when asked to create a kernel for an XLA-compiled function. 
-class XlaLocalLaunchBase : public OpKernel { - public: - XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function); - XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; - XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; - ~XlaLocalLaunchBase() override = default; - - void Compute(OpKernelContext* ctx) override; - - protected: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache); - - // Indexes of compile-time constant inputs - std::vector constants_; - // Indexes of resource inputs - std::vector resources_; - - DeviceType device_type_; - NameAttrList function_; - se::Platform::Id platform_id_; -}; - -// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph -// which will be compiled and executed using XLA. The XlaLocalLaunchOp is -// responsible for handling interactions with the TensorFlow executor. -// Once all inputs are present, and their shapes are known, the op can -// use a 'XlaCompilationCache' to compile and execute code which is specific -// to the shapes of input Tensors. -// XlaLocalLaunchOp uses xla::LocalClient::Compile() and -// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device -// memory. 
-class XlaLocalLaunchOp : public XlaLocalLaunchBase { - public: - explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); - ~XlaLocalLaunchOp() override; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..2268d9042860f6556cb69469ee52ad7cbbb81954 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -0,0 +1,518 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_ops.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +namespace { + +Status PlatformInfoFromContext(OpKernelConstruction* ctx, + XlaPlatformInfo* result) { + DeviceType device_type = ctx->device_type(); + se::Platform::Id platform_id = nullptr; + const XlaDevice::Metadata* xla_device_metadata = nullptr; + std::unique_ptr xla_allocator; + xla::DeviceMemoryAllocator* device_allocator = nullptr; + + if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + platform_id = se::host::kHostPlatformId; + } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { + platform_id = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); + } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + // If we are on an XlaDevice, use 
the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which xla_allocator below uses) as on an XlaDevice, this is a dummy + // allocator that returns XlaTensor objects. The XlaCompiler needs a real + // allocator to allocate real buffers. + + platform_id = xla_device_metadata->platform()->id(); + device_allocator = + xla_device_metadata->client()->backend().memory_allocator(); + } + + if (!device_allocator) { + TF_ASSIGN_OR_RETURN(se::Platform* const platform, + se::MultiPlatformManager::PlatformWithId(platform_id)); + xla_allocator = absl::make_unique( + platform, ctx->device()->GetAllocator({})); + } + + *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); + + return Status::OK(); +} + +// A closure describing how to run a compiled version of a TensorFlow function. +// +// It may seem unusual to stick the resource variable snapshots in this class. +// This is necessary: we need to use the snapshots observed by the compiler as +// the initial values for the resource variables (and cannot snapshot them again +// during execution) because otherwise we risk observing a different snapshot +// with shapes different from what we compiled for. 
+class XlaExecutableClosure { + public: + explicit XlaExecutableClosure( + xla::LocalClient* client, xla::LocalExecutable* executable, + const XlaCompiler::CompilationResult* compilation_result, + std::map resource_var_snapshots, + int num_constant_args) + : client_(client), + executable_(executable), + compilation_result_(compilation_result), + resource_var_snapshots_(std::move(resource_var_snapshots)), + num_constant_args_(num_constant_args) {} + + XlaExecutableClosure(XlaExecutableClosure&&) = default; + XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; + + xla::LocalClient* client() const { return client_; } + xla::LocalExecutable* executable() const { return executable_; } + const XlaCompiler::CompilationResult* compilation_result() const { + return compilation_result_; + } + const std::map& resource_var_snapshots() const { + return resource_var_snapshots_; + } + int num_constant_args() const { return num_constant_args_; } + + private: + xla::LocalClient* client_; + xla::LocalExecutable* executable_; + const XlaCompiler::CompilationResult* compilation_result_; + std::map resource_var_snapshots_; + int num_constant_args_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); +}; + +// This maintains a mapping from a globally unique ID to XlaExecutableClosure +// instances. 
+class XlaExecutableClosureStore { + public: + XlaExecutableClosureStore() : key_counter_(0) {} + + using KeyT = string; + + KeyT Produce(XlaExecutableClosure result) { + mutex_lock l(mutex_); + KeyT key = absl::StrCat(key_counter_++); + bool insert_successful = closures_.emplace(key, std::move(result)).second; + DCHECK(insert_successful); + (void)insert_successful; + return key; + } + + XlaExecutableClosure Consume(const KeyT& key) { + mutex_lock l(mutex_); + auto it = closures_.find(key); + DCHECK(it != closures_.end()); + XlaExecutableClosure value = std::move(it->second); + closures_.erase(it); + return value; + } + + static XlaExecutableClosureStore* Global() { + static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore; + return instance; + } + + private: + mutex mutex_; + int64 key_counter_ GUARDED_BY(mutex_); + absl::flat_hash_map closures_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); +}; + +} // namespace + +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + function_(function) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +static Status BuildCompilationCache(OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache) { + if (platform_info.xla_device_metadata()) { + *cache = new XlaCompilationCache( + platform_info.xla_device_metadata()->client(), + platform_info.xla_device_metadata()->jit_device_type()); + return Status::OK(); + } + + auto platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + if (!platform.ok()) { + return platform.status(); + } + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + 
ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); + if (!client.ok()) { + return client.status(); + } + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), + ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + platform_info.device_type().type()); + } + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); + return Status::OK(); +} + +static Status CompileToLocalExecutable( + OpKernelContext* ctx, const NameAttrList& function, + const XlaPlatformInfo& platform_info, absl::Span resources, + absl::Span constants, bool lazy, xla::LocalClient** client, + std::map* variables, + const XlaCompiler::CompilationResult** kernel, + xla::LocalExecutable** executable) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + if (!rm) { + return errors::Internal("No resource manager."); + } + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, platform_info, cache); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) 
+ core::ScopedUnref cache_ref(cache); + + *variables = SnapshotResourceVariables(ctx, resources); + *client = static_cast(cache->client()); + + XlaCompiler::Options options; + options.client = *client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } + options.device_type = cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = + (platform_info.platform_id() == se::host::kHostPlatformId); + options.device_allocator = platform_info.allocator(); + if (platform_info.xla_device_metadata()) { + options.shape_representation_fn = + platform_info.xla_device_metadata()->shape_representation_fn(); + } + + std::map constant_args; + for (int i : constants) { + constant_args.insert({i, ctx->input(i)}); + } + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + // If we resolve constants we never emit them on the device, meaning that if + // they are needed by a following computation the host has to transfer + // them. Not resolving constants is expected to be faster than resolving + // constants. + compile_options.resolve_compile_time_constants = true; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + + return cache->Compile(options, function, constant_args, *variables, ctx, + compile_options, + lazy ? 
XlaCompilationCache::CompileMode::kLazy + : XlaCompilationCache::CompileMode::kStrict, + kernel, executable); +} + +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, /*lazy=*/false, &client, + &variables, &kernel, &executable)); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + VLOG(1) << "Executing XLA Computation..."; + + XlaComputationLaunchContext launch_context( + client, platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + platform_info_.UseMultipleStreams()); + launch_context.PopulateInputs(ctx, kernel, variables, + /*missing_ctx_input_prefix=*/0); + + // Execute the computation. + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; + + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); + VLOG(1) << "Done"; +} + +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) 
\ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. +std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + +XlaLocalLaunchOp::~XlaLocalLaunchOp() { + VLOG(1) << "XlaLocalLaunchOp destroyed"; +} + +XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) + : OpKernel(ctx), + constants_(ConstantsVector(ctx)), + resources_(ResourcesVector(ctx)), + function_(FunctionAttr(ctx)) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); +} + +void 
XlaCompileOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaCompileOp " << def().name() + << (must_compile_ ? "(must-compile)" : ""); + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, /*lazy=*/!must_compile_, + &client, &variables, &kernel, &executable)); + + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs); + + if (!executable) { + DCHECK(!must_compile_); + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.scalar()() = false; + ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({}))); + ctx->set_output(1, compilation_successful); + return; + } + + // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even + // if it didn't have to compile the cluster because of a compilation-cache + // hit. This is because we at least need new snapshots of the resource + // variables. 
+ XlaExecutableClosureStore::KeyT key = + XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( + client, executable, kernel, std::move(variables), constants_.size())); + + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + compilation_key.flat()(0) = key; + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.flat()(0) = true; + + ctx->set_output(0, compilation_key); + ctx->set_output(1, compilation_successful); +} + +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaRunOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaRunOp " << def().name(); + Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); + const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); + + XlaExecutableClosure closure = + XlaExecutableClosureStore::Global()->Consume(key); + + XlaComputationLaunchContext launch_context( + closure.client(), platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); + + // We're missing the must-be-constant inputs, tell `PopulateInputs` + // about this. We don't actually need these inputs because they've + // already been baked into the compiled kernel. + launch_context.PopulateInputs( + ctx, closure.compilation_result(), closure.resource_var_snapshots(), + /*missing_ctx_input_prefix=*/closure.num_constant_args()); + + se::Stream* stream = + ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = + closure.executable()->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; + + OP_REQUIRES_OK( + ctx, + launch_context.PopulateOutputs( + ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/closure.num_constant_args())); +} + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); +REGISTER_KERNEL_BUILDER(Name("_XlaCompile") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("key") + .HostMemory("compilation_successful") + .HostMemory("resources"), + XlaCompileOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ac90837e0d90943b93e2cdb01a30fa0837ba94df --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +// Holds some information about the platform on which an +// XlaLaunch/_XlaCompile/_XlaRun op must run. 
+class XlaPlatformInfo { + public: + XlaPlatformInfo() : device_type_("") {} + explicit XlaPlatformInfo(const DeviceType device_type, + se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + std::unique_ptr xla_allocator, + xla::DeviceMemoryAllocator* device_allocator) + : device_type_(device_type), + platform_id_(platform_id), + xla_device_metadata_(xla_device_metadata), + xla_allocator_(std::move(xla_allocator)), + device_allocator_(device_allocator) { + CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr)); + } + + XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; + + bool UseMultipleStreams() const { + return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); + } + + xla::DeviceMemoryAllocator* allocator() const { + return device_allocator_ ? device_allocator_ : xla_allocator_.get(); + } + DeviceType device_type() const { return device_type_; } + + // This is equal to xla_device_metadata()->platform()->id() if + // xla_device_metadata() is not nullptr. + se::Platform::Id platform_id() const { return platform_id_; } + + // This may be null if the op this XlaPlatformInfo is for was not placed on an + // XLA device. + const XlaDevice::Metadata* xla_device_metadata() const { + return xla_device_metadata_; + } + bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + + private: + DeviceType device_type_; + se::Platform::Id platform_id_; + + // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the + // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the + // XlaLaunch/_XlaCompile/_XlaRun OpKernel. + const XlaDevice::Metadata* xla_device_metadata_; + + // If the op associated with this XlaPlatformInfo is placed on an XLA device + // then device_allocator_ is the xla::Backend's memory allocator and + // xla_allocator_ is null. 
If the op is placed on a regular CPU or GPU device + // then device_allocator_ is null and xla_allocator_ points to an appropriate + // XlaAllocator instance. + std::unique_ptr xla_allocator_; + xla::DeviceMemoryAllocator* device_allocator_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); +}; + +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + NameAttrList function_; + XlaPlatformInfo platform_info_; +}; + +// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaLocalLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'XlaCompilationCache' to compile and execute code which is specific +// to the shapes of input Tensors. +// XlaLocalLaunchOp uses xla::LocalClient::Compile() and +// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device +// memory. 
+class XlaLocalLaunchOp : public XlaLocalLaunchBase { + public: + explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); + ~XlaLocalLaunchOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); +}; + +class XlaCompileOp : public OpKernel { + public: + explicit XlaCompileOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + NameAttrList function_; + + XlaPlatformInfo platform_info_; + + bool must_compile_; +}; + +class XlaRunOp : public OpKernel { + public: + explicit XlaRunOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + XlaPlatformInfo platform_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5b6692f523658749f7ef48f9d7d89e97d4ce8b09..d8fe4026f51d8aa4b027aeedf0795ad30e28d986 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -29,9 +29,9 @@ cc_library( ) cc_library( - name = "parallel_check_op_flags", - srcs = ["parallel_check_op_flags.cc"], - hdrs = ["parallel_check_op_flags.h"], + name = "xla_device_flags", + srcs = ["xla_device_flags.cc"], + hdrs = ["xla_device_flags.h"], deps = [ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", @@ -41,9 +41,9 @@ cc_library( ) cc_library( - name = "xla_device_flags", - srcs = ["xla_device_flags.cc"], - hdrs = ["xla_device_flags.h"], + name = "build_xla_ops_pass_flags", + srcs = ["build_xla_ops_pass_flags.cc"], + hdrs = ["build_xla_ops_pass_flags.h"], deps = [ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc new 
file mode 100644 index 0000000000000000000000000000000000000000..58157d2b9800a2e8269533607c2ea688ff4e7766 --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include // NOLINT + +#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { +namespace { + +BuildXlaOpsPassFlags* flags; +std::vector* flag_list; +std::once_flag flags_init; + +void AllocateAndParseFlags() { + flags = new BuildXlaOpsPassFlags; + flags->tf_xla_enable_lazy_compilation = false; + flag_list = new std::vector({ + Flag("tf_xla_enable_lazy_compilation", + &flags->tf_xla_enable_lazy_compilation, ""), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +} // namespace + +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *flags; +} +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h new file mode 100644 index 
0000000000000000000000000000000000000000..539314cbf72d38ed973b8a526aa6424b19ef344d --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ + +namespace tensorflow { +namespace legacy_flags { + +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to false. + bool tf_xla_enable_lazy_compilation; +}; + +// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment +// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS +// only the first time this routine is called. 
+const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc deleted file mode 100644 index a61694b49407b923b7c83f35e903ef49a2175f0e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ParallelCheckOpFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new ParallelCheckOpFlags; - flags->parallel_check_failfast = true; - flags->parallel_check_atol = "1e-5"; - flags->parallel_check_rtol = "1e-5"; - flag_list = new std::vector({ - Flag("parallel_check_failfast", &flags->parallel_check_failfast, - "Fail immediately on first parallel-check comparison error."), - Flag("parallel_check_atol", &flags->parallel_check_atol, - "Absolute error tolerance for parallel-check comparison."), - Flag("parallel_check_rtol", &flags->parallel_check_rtol, - "Relative error tolerance for parallel-check comparison."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h deleted file mode 100644 index 156a2a2a71097631e24d154b102cd9b85a990b3a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// parallel_check_op module. -typedef struct { - bool parallel_check_failfast; // Fail immediately on first parallel-check - // comparison error. - string parallel_check_atol; // Absolute error tolerance for parallel-check - // comparison. - string parallel_check_rtol; // Relative error tolerance for parallel-check - // comparison. -} ParallelCheckOpFlags; - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-ParallelCheckOpFlags* GetParallelCheckOpFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 518c39ec15e0a962ee251ca3e630a7c75abf04ff..4f0c370e65159c89c91ea58733f20f852d9acc99 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -42,8 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -366,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) { return elementwise_ops->count(node.op()) > 0; } +// Nodes that XLA can compile are put in `candidates`. Nodes put in +// `isolated_nodes` must either be unclustered or be put in trivial single-node +// clusters. 
Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - OrderedNodeSet* candidates) { + OrderedNodeSet* candidates, absl::flat_hash_set* isolated_nodes) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -412,6 +414,8 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); + VLOG(4) << "Device type for " << node->name() << ": " + << device_type.type_string(); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { // is_compilable_fn has already logged the reason if it returned false. @@ -440,19 +444,56 @@ Status FindCompilationCandidates( << node->type_string(); continue; } - if (compile_time_const_nodes[node->id()] && - !registration->requires_compilation) { + if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( - OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { - // We need to be able to constant fold the nodes in - // compile_time_const_nodes given constant inputs (required by XLA) and - // therefore can't auto-cluster stateful ops since these can never be - // constant folded. - VLOG(2) << "Rejecting " << node->name() - << ": must-be-constant stateful op"; - continue; + // It is easiest to demonstrate the problem we're trying to solve with + // an example. Say we have this graph: + // + // shape = RandomUniformInt(); + // reshape = Reshape(input, shape) + // + // Both RandomUniformInt and Reshape are compilable by XLA so, absent + // any other reason, we will try to put both shape and reshape in the + // same cluster. 
However, since XLA only supports statically shaped + // values, it will expect to be able to constant fold `shape` to get a + // static shape for `reshape`. This is a problem because side-effecting + // ops like RandomUniformInt() cannot be constant folded. We fix this + // by putting `shape` and `reshape` in different clusters, which results + // in us recompiling `reshape`'s cluster for every new value of `shape`, + // making `reshape` statically sized within each compilation. We + // simplify the solution even further by disallowing operations like + // `shape` from being part of *any* non-trivial cluster. They're either + // not compiled by XLA altogether or, if assigned to an XLA_* device + // with "must compile" semantics, compiled into a trivial single-op + // cluster. This approach leaves some room for improvement, and we can + // consider implementing a more aggressive data-flow-analysis based + // solution in the future if needed. + // + // One ugly problem we have to contend with: certain sets of ops *have* + // to be in the same cluster because values flowing between them have + // types that can't be live-in or live-out of a cluster. These ops are: + // + // - TensorArray ops operating on the same TensorArray instance. + // - Stack ops operating on the same Stack instance. + // + // To work around this we avoid isolating these specific ops. Because + // of this concession it is unsound to auto-cluster them because then + // we'd create clusters we could not compile (because we can't constant + // fold, say, a TensorArrayRead or a StackPopV2). But we don't + // auto-cluster these operations today so we're good for now. 
+ const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node->type_string()); + bool is_tensor_array_or_stack_op = + op_info && op_info->resource_kind() != XlaResourceKind::kVariable; + if (!is_tensor_array_or_stack_op) { + VLOG(2) << "Isolating " << node->name() + << ": must-be-constant stateful op"; + isolated_nodes->insert(node); + // Keep going and execute all the other checks. + } } } // We don't auto-cluster functional control flow nodes containing resource @@ -617,7 +658,7 @@ Status MarkForCompilationPass::Run( } static string RatioToString(int numerator, int denominator) { - return strings::Printf("%d / %d (%.2f%%)", numerator, denominator, + return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } @@ -626,14 +667,14 @@ static void VLogClusteringSummary(const Graph& g) { return; } - std::map cluster_name_to_size; - std::map> + std::map cluster_name_to_size; + std::map> cluster_name_to_op_histogram; - std::map unclustered_op_histogram; + std::map unclustered_op_histogram; int clustered_node_count = 0; for (Node* n : g.nodes()) { - absl::optional cluster_name = GetXlaClusterForNode(*n); + absl::optional cluster_name = GetXlaClusterForNode(*n); if (cluster_name) { clustered_node_count++; cluster_name_to_size[*cluster_name]++; @@ -650,7 +691,7 @@ static void VLogClusteringSummary(const Graph& g) { << RatioToString(clustered_node_count, g.num_nodes()); for (const auto& cluster_name_size_pair : cluster_name_to_size) { - StringPiece cluster_name = cluster_name_size_pair.first; + absl::string_view cluster_name = cluster_name_size_pair.first; int size = cluster_name_size_pair.second; VLOG(2) << " " << cluster_name << " " << RatioToString(size, g.num_nodes()); @@ -668,6 +709,85 @@ static void VLogClusteringSummary(const Graph& g) { VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; } } + + struct EdgeInfo { + absl::string_view node_name; + absl::optional cluster_name; + + absl::string_view 
GetClusterName() const { + return cluster_name ? *cluster_name : "[none]"; + } + + std::pair> AsPair() + const { + return {node_name, cluster_name}; + } + + bool operator<(const EdgeInfo& other) const { + return AsPair() < other.AsPair(); + } + }; + + using EdgeInfoMap = std::map>; + + EdgeInfoMap incoming_edge_infos; + EdgeInfoMap outgoing_edge_infos; + + std::set cluster_names_to_print; + + for (const Edge* e : g.edges()) { + const Node* from = e->src(); + absl::optional from_cluster_name = + GetXlaClusterForNode(*from); + + const Node* to = e->dst(); + absl::optional to_cluster_name = + GetXlaClusterForNode(*to); + + if (to_cluster_name == from_cluster_name) { + continue; + } + + if (to_cluster_name) { + incoming_edge_infos[*to_cluster_name] + [EdgeInfo{from->name(), from_cluster_name}]++; + cluster_names_to_print.insert(*to_cluster_name); + } + + if (from_cluster_name) { + outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++; + cluster_names_to_print.insert(*from_cluster_name); + } + } + + VLOG(2) << "*** Inter-Cluster edges:"; + if (cluster_names_to_print.empty()) { + VLOG(2) << " [none]"; + } + + auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name, + const EdgeInfoMap& edge_info_map, + absl::string_view desc) { + auto it = edge_info_map.find(cluster_name); + if (it != edge_info_map.end()) { + VLOG(2) << " " << it->second.size() << " " << desc << " edges"; + for (const auto& edge_info_count_pair : it->second) { + VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " " + << edge_info_count_pair.first.node_name << " # " + << edge_info_count_pair.second; + } + } else { + VLOG(2) << " No " << desc << " edges."; + } + }; + + for (absl::string_view cluster_name : cluster_names_to_print) { + VLOG(2) << " ** Cluster " << cluster_name; + print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, + "incoming"); + print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos, + "outgoing"); + } } // Is 
'node' an operator that consumes only the shape of its input, not the @@ -729,11 +849,12 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; + absl::flat_hash_set isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env : Env::Default(), - is_compilable_fn, &compilation_candidates)); + is_compilable_fn, &compilation_candidates, &isolated_nodes)); if (compilation_candidates.empty()) { VLOG(2) << "No compilable candidates"; @@ -778,6 +899,11 @@ Status MarkForCompilationPass::RunImpl( "Found control flow node in clustering worklist: ", node_from->type_string()); } + + if (isolated_nodes.count(node_from)) { + continue; + } + string from_scope; string to_scope; for (int to : cycles.Successors(from)) { @@ -795,6 +921,9 @@ Status MarkForCompilationPass::RunImpl( node_to->assigned_device_name()) { continue; } + if (isolated_nodes.count(node_to)) { + continue; + } // Look for an _XlaScope on both nodes. If both nodes have a // scope and the scopes do not match, do not cluster along this // edge. This restriction is overridden if the global_jit_level is ON. If @@ -853,6 +982,11 @@ Status MarkForCompilationPass::RunImpl( // Names for each cluster. 
std::unordered_map cluster_names; + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph, + options.flib_def); + } + // Mark clusters for compilation that: // * are placed on a device that requires compilation (an XlaDevice), // * are explicitly marked for compilation (_XlaCompile=true), or @@ -890,7 +1024,7 @@ Status MarkForCompilationPass::RunImpl( string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 807ab51fd3c133b95915ea88e0bf99dbb8661452..2a80c745e3fcebf97bcccb03551feb3d6fb9f831 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -60,10 +62,10 @@ std::unordered_map GetClusters(const Graph& graph) { return ids; } -gtl::FlatMap> GetClusterSets( +absl::flat_hash_map> GetClusterSets( const Graph& g, std::vector* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - gtl::FlatMap> cluster_sets; + absl::flat_hash_map> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -565,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", @@ -585,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector expected_clustered_nodes = {"AssignmentW", @@ -615,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::vector cluster_names; - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 2); @@ -633,7 +635,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr graph(new 
Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = [](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); @@ -847,5 +849,117 @@ TEST(XlaCompilationTest, RandomShape) { EXPECT_EQ(clusters["shape"], ""); } +TEST(XlaCompilationTest, RandomShapeWithFunc) { + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/"Stateful_func", /*in_def=*/{}, + /*out_def=*/{"out: int32"}, + /*attr_def*/ + {}, /*node_def=*/ + {FunctionDefHelper::Const("shape_shape", 2), + FunctionDefHelper::Const("minval", 1), + FunctionDefHelper::Const("maxval", 20), + {{"shape"}, + "RandomUniformInt", + {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, + {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}}, + /*ret_def=*/{{"out", "shape:output:0"}}); + + func.mutable_signature()->set_is_stateful(true); + *flib_def.add_function() = std::move(func); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + NodeDef call_node; + call_node.set_name("fn_call"); + call_node.set_op("Stateful_func"); + Status status; + Node* call = root.graph()->AddNode(call_node, &status); + TF_ASSERT_OK(status); + + Output shape = Output(call, 0); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + auto fld = absl::make_unique(OpRegistry::Global(), + flib_def); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); + + std::unordered_map clusters = GetClusters(*graph); + 
EXPECT_EQ(clusters["fn_call"], ""); +} + +TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = + ops::Const(root.WithOpName("test/shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape, + ops::Const(root.WithOpName("test/minval"), 1), + ops::Const(root.WithOpName("test/maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/shape_rng"], ""); + EXPECT_NE(clusters["test/reshape"], ""); + EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); +} + +TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + Scope root = Scope::NewRootScope().ExitOnError(); + ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1, + DT_INT32); + Output zero = ops::Const(root.WithOpName("test/zero"), 0); + ops::TensorArrayWrite tensor_array_write( + root.WithOpName("test/write"), tensor_array.handle, zero, + ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow); + Output tensor_array_read = + ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle, + zero, tensor_array_write.flow_out, DT_INT32); + Output reshape = + 
ops::Reshape(root.WithOpName("test/reshape"), + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT), + tensor_array_read); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/read"], ""); + EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 65669877f732bad9e145da36a3aedeba611a0fe5..d56d0f8ccfcdab40003be38059228cb255921b64 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -14,18 +14,35 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, SessionOptions* session_options) { - // Assign all nodes to the CPU device. + // Assign all unassigned nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } } + // Call AddDevices to register the XLA devices. 
+ // + // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to + // make this more direct, but probably not worth it solely for this test. + std::vector devices; + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); + + auto delete_devices = gtl::MakeCleanup([&] { + for (Device* d : devices) { + delete d; + } + }); + GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.session_options = session_options; diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc new file mode 100644 index 0000000000000000000000000000000000000000..a09a6eb1553cb4bcf5587a7602097a40b64cfcdf --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -0,0 +1,512 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { +namespace testing { +namespace matchers { +namespace { + +using impl::NodeMatcherProperties; +using impl::OutEdge; + +string IndentAllButFirstLine(absl::string_view text) { + std::vector lines = absl::StrSplit(text, '\n'); + for (int i = 1; i < lines.size(); i++) { + lines[i].insert(0, " "); + } + return absl::StrJoin(lines, "\n"); +} + +template +bool CompareTensor(const Tensor& actual, const Tensor& expected, + ::testing::MatchResultListener* listener) { + if (actual.NumElements() != expected.NumElements()) { + if (listener->IsInterested()) { + *listener << "\nwas looking for tensor with " << expected.NumElements() + << " elements, found tensor with " << actual.NumElements() + << " elements"; + return false; + } + } + + for (int64 i = 0, e = actual.NumElements(); i < e; i++) { + if (actual.flat()(i) != expected.flat()(i)) { + *listener << "\nmismatch in constant tensor at index " << i + << " expected = " << expected.flat()(i) + << " actual = " << actual.flat()(i); + return false; + } + } + + return true; +} + +bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, + ::testing::MatchResultListener* listener) { + if (tensor.dtype() != expected_tensor.dtype()) { + if (listener->IsInterested()) { + *listener << "\nexpected tensor of type " + << DataType_Name(expected_tensor.dtype()) + << " but found one of type " << DataType_Name(tensor.dtype()); + return false; + } + } + + switch (tensor.dtype()) { + case DT_FLOAT: + return 
CompareTensor(tensor, expected_tensor, listener); + case DT_DOUBLE: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT8: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT16: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT32: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT64: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT8: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT16: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT32: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT64: + return CompareTensor(tensor, expected_tensor, listener); + default: + LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly. + << DataType_Name(tensor.dtype()); + } +} + +struct NodeMatcher : public ::testing::MatcherInterface { + bool MatchAndExplain( + const Node* node, + ::testing::MatchResultListener* listener) const override { + if (op && node->type_string() != *op) { + if (listener->IsInterested()) { + *listener << "\nexpected op " << *op << " but found " + << node->type_string(); + } + return false; + } + + if (assigned_device && node->assigned_device_name() != *assigned_device) { + if (listener->IsInterested()) { + *listener << "\nexpected assigned_device " << *assigned_device + << " but found \"" << node->assigned_device_name() << "\""; + } + return false; + } + + if (name && node->name() != *name) { + if (listener->IsInterested()) { + *listener << "\nexpected name " << *name << " but found " + << node->name(); + } + return false; + } + + if (constant_value) { + const TensorProto* proto = nullptr; + if (!GetNodeAttr(node->def(), "value", &proto).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find \"value\" attribute in node"; + } + return false; + } + + Tensor tensor(proto->dtype()); + if (!tensor.FromProto(*proto)) { + if (listener->IsInterested()) { + *listener 
<< "\ncould not convert TensorProto in \"value\" attribute " + "to Tensor"; + } + return false; + } + + if (!MatchAndExplainTensor(/*tensor=*/tensor, + /*expected_tensor=*/*constant_value, + listener)) { + return false; + } + } + + if (input_matchers) { + if (input_matchers->size() != node->num_inputs()) { + if (listener->IsInterested()) { + *listener << "\nexpected " << input_matchers->size() + << " inputs but node has " << node->num_inputs(); + } + return false; + } + + for (int input_idx = 0, e = input_matchers->size(); input_idx < e; + input_idx++) { + if (!MatchAndExplainInput(node, input_idx, listener)) { + return false; + } + } + } + + std::vector control_deps; + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) { + control_deps.push_back(e->src()); + } + } + + ::testing::StringMatchResultListener inner_listener; + if (control_dep_set && + !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { + if (listener->IsInterested()) { + string explanation = inner_listener.str(); + if (!explanation.empty()) { + explanation = absl::StrCat(", ", explanation, ","); + } + *listener << "ctrl_deps" << explanation << " does not match expected: "; + control_dep_set->DescribeTo(listener->stream()); + } + return false; + } + + const AttrValueMap attr_value_map = node->def().attr(); + for (const auto& attr_kv_pair : attrs) { + auto it = attr_value_map.find(attr_kv_pair.first); + if (it == attr_value_map.end()) { + if (listener->IsInterested()) { + *listener << "did not find attribute named \"" << attr_kv_pair.first + << "\" in node"; + } + return false; + } + if (!AreAttrValuesEqual(it->second, attr_kv_pair.second)) { + if (listener->IsInterested()) { + *listener << "attribute named " << attr_kv_pair.first + << " does not match value; expected: \"" + << SummarizeAttrValue(attr_kv_pair.second) + << "\", found: \"" << SummarizeAttrValue(it->second) + << "\""; + } + return false; + } + } + + return true; + } + + void DescribeTo(::std::ostream* os) 
const override { + std::vector predicates; + + if (name) { + predicates.push_back(absl::StrCat("name: ", *name)); + } + + if (op) { + predicates.push_back(absl::StrCat("op: ", *op)); + } + + if (assigned_device) { + predicates.push_back(absl::StrCat("assigned device: ", *assigned_device)); + } + + bool printed_something = !predicates.empty(); + + *os << absl::StrJoin(predicates, ", "); + + if (constant_value) { + printed_something = true; + *os << "constant value: " << constant_value->DebugString(); + } + + if (input_matchers) { + if (!input_matchers->empty()) { + printed_something = true; + *os << " with " << (input_matchers->size() == 1 ? "only " : "") + << "input" << (input_matchers->size() == 1 ? "" : "s") << " "; + } + + if (input_matchers->size() == 1) { + ::std::stringstream ss; + input_matchers->front().DescribeTo(&ss); + printed_something = true; + *os << "matching " << ss.str(); + } else { + int edge_idx = 0; + for (const ::testing::Matcher& matcher : (*input_matchers)) { + *os << "\n [" << edge_idx << "] matching ("; + ::std::stringstream ss; + matcher.DescribeTo(&ss); + printed_something = true; + *os << IndentAllButFirstLine(ss.str()); + *os << ")"; + edge_idx++; + } + } + } + + if (control_dep_set) { + printed_something = true; + *os << " and control deps "; + control_dep_set->DescribeTo(os); + } + + if (!attrs.empty()) { + printed_something = true; + std::vector attrs_str; + absl::c_transform(attrs, std::back_inserter(attrs_str), + [](const std::pair& attr_kv_pair) { + return absl::StrCat( + attr_kv_pair.first, "->", + SummarizeAttrValue(attr_kv_pair.second)); + }); + *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ") + << "]"; + } + + if (!printed_something) { + *os << "is any node"; + } + } + + bool MatchAndExplainInput(const Node* node, int input_idx, + ::testing::MatchResultListener* listener) const { + const Edge* edge; + if (!node->input_edge(input_idx, &edge).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould 
not find incoming edge for input " << input_idx; + } + return false; + } + + ::testing::StringMatchResultListener inner_listener; + OutEdge input = {edge->src(), edge->src_output()}; + if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { + return true; + } + + if (listener->IsInterested()) { + *listener << "\ninput " << input_idx << " does not match expected:\n"; + (*input_matchers)[input_idx].DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << ", " << explanation; + } + } + return false; + } + + absl::optional op; + absl::optional name; + absl::optional assigned_device; + absl::optional constant_value; + absl::optional>> input_matchers; + absl::optional<::testing::Matcher>> + control_dep_set; + std::map attrs; +}; + +// Matches a src and src_output on an input edge. Today we only use this with +// src_output=0 but we will eventually need to support multi-output operations. +class OutEdgeMatcher : public ::testing::MatcherInterface { + public: + OutEdgeMatcher(::testing::Matcher src_matcher, int src_oidx) + : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {} + + bool MatchAndExplain( + OutEdge out_edge, + ::testing::MatchResultListener* listener) const override { + ::testing::StringMatchResultListener inner_listener; + if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) { + if (listener->IsInterested()) { + *listener << "\nsource does not match expected "; + src_matcher_.DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << "\n\t" << explanation; + } + } + return false; + } + if (out_edge.second != src_oidx_) { + if (listener->IsInterested()) { + *listener << "\nexpected output slot to be " << src_oidx_ + << " but found " << out_edge.second; + } + return false; + } + + return true; + } + + void DescribeTo(::std::ostream* os) const override { + if (src_oidx_) { + *os << "output 
slot: " << src_oidx_ << ", source: ("; + } + + src_matcher_.DescribeTo(os); + + if (src_oidx_) { + *os << ")"; + } + } + + private: + ::testing::Matcher src_matcher_; + int src_oidx_; +}; +} // namespace + +::testing::Matcher impl::NodeWith( + absl::Span props) { + NodeMatcher* matcher = new NodeMatcher(); + for (const NodeMatcherProperties& prop : props) { + if (prop.name()) { + DCHECK(!matcher->name); + matcher->name = prop.name(); + } + + if (prop.op()) { + DCHECK(!matcher->op); + matcher->op = prop.op(); + } + + if (prop.constant_value()) { + DCHECK(!matcher->constant_value); + matcher->constant_value = prop.constant_value(); + } + + if (prop.assigned_device()) { + DCHECK(!matcher->assigned_device); + matcher->assigned_device = prop.assigned_device(); + } + + if (prop.inputs()) { + DCHECK(!matcher->input_matchers); + matcher->input_matchers = *prop.inputs(); + } + + if (prop.control_deps()) { + DCHECK(!matcher->control_dep_set); + matcher->control_dep_set = + ::testing::UnorderedElementsAreArray(*prop.control_deps()); + } + + if (prop.attr()) { + auto insert_result = matcher->attrs.insert(*prop.attr()); + DCHECK(insert_result.second); + } + } + + return ::testing::MakeMatcher(matcher); +} + +impl::NodeMatcherProperties Name(string name) { + impl::NodeMatcherProperties props; + props.set_name(std::move(name)); + return props; +} + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op) { + impl::NodeMatcherProperties props; + props.set_op(std::move(op)); + return props; +} + +// Matches a node with assigned device `assigned_device`. 
+impl::NodeMatcherProperties AssignedDevice(string assigned_device) { + impl::NodeMatcherProperties props; + props.set_assigned_device(std::move(assigned_device)); + return props; +} + +impl::NodeMatcherProperties impl::Inputs( + absl::Span> inputs) { + std::vector<::testing::Matcher> inputs_vector; + absl::c_copy(inputs, std::back_inserter(inputs_vector)); + + impl::NodeMatcherProperties props; + props.set_inputs(std::move(inputs_vector)); + return props; +} + +impl::NodeMatcherProperties impl::CtrlDeps( + absl::Span> control_deps) { + std::vector<::testing::Matcher> control_deps_vector; + absl::c_copy(control_deps, std::back_inserter(control_deps_vector)); + + impl::NodeMatcherProperties props; + props.set_control_deps(std::move(control_deps_vector)); + return props; +} + +std::pair impl::AttrLiteralHelper( + const std::pair& bool_attr) { + AttrValue attr_value; + attr_value.set_b(bool_attr.second); + return {bool_attr.first, attr_value}; +} + +impl::NodeMatcherProperties impl::Attr(std::pair attr) { + impl::NodeMatcherProperties props; + props.set_attr(std::move(attr)); + return props; +} + +NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val) { + TF_CHECK_OK(val.status); + NodeMatcherProperties props; + props.set_constant_value(val.tensor); + return props; +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val) { + return NodeWith(ConstantValue(val)); +} +::testing::Matcher Out( + int oidx, ::testing::Matcher node_matcher) { + return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx)); +} +} // namespace matchers + +Node* FindNodeByName(Graph* g, absl::string_view name) { + for (Node* n : g->nodes()) { + if (n->name() == name) { + return n; + } + } + + return nullptr; +} +} // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); } +void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); } +} // namespace tensorflow diff --git 
a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h new file mode 100644 index 0000000000000000000000000000000000000000..35c2f5fd7b533d0e8716dc6c70c21afe9a32c9c8 --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.h @@ -0,0 +1,239 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a set of matchers for tensorflow nodes. +// +// Example usage: +// +// tensorflow::Node* node = ...; +// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), +// Inputs(Out(3, NodeWith(Name("input")))))) +// +// Matchable node properties (the expressions that go inside NodeWith(...)) +// are: +// +// - Name(string): matches the node name exactly. We will probably need to +// have this take a string matcher soon in the future. +// +// - Op(string): matches the op exactly. +// +// - AssignedDevice(string): matches the assigned device exactly. +// +// - Inputs(): matches the list of non-control inputs to the node +// exactly (i.e. does not match a suffix or a prefix) where each element +// matches an output of a node (see Out(idx, node) below). +// +// - CtrlDeps(): matches the list of control dependences on the +// node exactly but in any order. +// +// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node +// with the constant value `init`. Implies Op("Const"). 
+// +// - Attr(name, value): Matches a single attribute with name `name` and value +// `value`. Right now only boolean values are supported. +// +// Overlapping node properties may not be repeated in a single NodeWith(...) +// matcher. E.g. NodeWith(Op("Foo"), Op("Bar")) will DCHECK-fail. Since +// ConstantValue implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). Multiple Attr() values can be combined as +// long as the attribute names are different. +// +// Out(idx, node) matches the `idx`'th output of a node that matches `node`. + +#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ +#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace testing { +namespace matchers { + +namespace impl { + +using OutEdge = std::pair; + +// ----------------------------------------------------------------------------- +// Implementation details. + +// Properties that we match on for a particular Node. If a particular property +// is nullopt then any value for it is allowed. 
+class NodeMatcherProperties { + public: + using NodeSeqMatcher = std::vector<::testing::Matcher>; + using InputSeqMatcher = std::vector<::testing::Matcher>; + using AttrKeyValuePair = std::pair; + + const absl::optional& name() const { return name_; } + const absl::optional& op() const { return op_; } + const absl::optional& assigned_device() const { + return assigned_device_; + } + const absl::optional& constant_value() const { + return constant_value_; + } + const absl::optional& inputs() const { + return input_matchers_; + } + const absl::optional& control_deps() const { + return control_deps_; + } + const absl::optional& attr() const { return attr_; } + + void set_name(string name) { + DCHECK(IsEmpty()); + name_ = std::move(name); + } + + void set_op(string op) { + DCHECK(IsEmpty()); + op_ = std::move(op); + } + + void set_assigned_device(string assigned_device) { + DCHECK(IsEmpty()); + assigned_device_ = std::move(assigned_device); + } + + void set_constant_value(Tensor constant_value) { + DCHECK(IsEmpty()); + constant_value_ = std::move(constant_value); + op_ = "Const"; + } + + void set_inputs(InputSeqMatcher inputs) { + DCHECK(IsEmpty()); + input_matchers_ = std::move(inputs); + } + + void set_control_deps(NodeSeqMatcher control_deps) { + DCHECK(IsEmpty()); + control_deps_ = std::move(control_deps); + } + + void set_attr(AttrKeyValuePair attr) { + DCHECK(IsEmpty()); + attr_ = std::move(attr); + } + + bool IsEmpty() const { + return !name().has_value() && !op().has_value() && !inputs().has_value() && + !control_deps().has_value() && !attr().has_value(); + } + + private: + absl::optional name_; + absl::optional op_; + absl::optional assigned_device_; + absl::optional constant_value_; + absl::optional input_matchers_; + absl::optional control_deps_; + absl::optional attr_; +}; + +::testing::Matcher NodeWith( + absl::Span props); + +impl::NodeMatcherProperties Inputs( + absl::Span> inputs); + +impl::NodeMatcherProperties CtrlDeps( + absl::Span> control_deps); + 
+impl::NodeMatcherProperties Attr(std::pair attrs); + +std::pair AttrLiteralHelper( + const std::pair& bool_attr); +} // namespace impl + +// ----------------------------------------------------------------------------- +// Public interface. + +// Matches a node with name `name`. +impl::NodeMatcherProperties Name(string name); + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op); + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device); + +// Matches a node with a boolean typed attribute named `name` and with value +// `value`. +template +impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { + return impl::Attr({impl::AttrLiteralHelper({name, value})}); +} + +// Matches a node with inputs `inputs`. +// +// `inputs` are ordered; `inputs`[i] must match input i. +template +impl::NodeMatcherProperties Inputs(Ts... inputs) { + return impl::Inputs({inputs...}); +} + +// Matches the `idx`'th output of a node that matches `node`. +::testing::Matcher Out(int oidx, + ::testing::Matcher node); + +// Matches the first output of a node that matches `node`. +::testing::Matcher Out(::testing::Matcher node) { + return Out(0, node); +} + +// Matches a node with control dependences `control_deps`. +// +// `control_deps` are unordered and will match the control deps of a node in any +// order. +template +impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { + return impl::CtrlDeps({control_deps...}); +} + +// Matches a constant node with value `val`. +impl::NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val); + +// The main gmock matcher. See file comment for example usage. +template +::testing::Matcher NodeWith(Ts... 
args) { + std::array array = {args...}; + return impl::NodeWith(array); +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val); +} // namespace matchers + +// If `g` has a node named `name` returns it, otherwise returns null. +Node* FindNodeByName(Graph* g, absl::string_view name); +} // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os); +void PrintTo(Node* n, ::std::ostream* os); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3f0dfece85573d71dbfa21eba5af70b674fe71e --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -0,0 +1,214 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/math_ops.h" + +namespace tensorflow { +namespace testing { +namespace { + +using ::testing::_; + +using testing::matchers::AssignedDevice; +using testing::matchers::Attr; +using testing::matchers::ConstantValue; +using testing::matchers::CtrlDeps; +using testing::matchers::Inputs; +using testing::matchers::Name; +using testing::matchers::NodeWith; +using testing::matchers::Op; +using testing::matchers::Out; + +template +string Explain(const T& t, const M& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(t, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(t, &listener)); + return listener.str(); +} + +TEST(NodeMatchers, CheckAgainstConstant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Name("placeholder"), Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Inputs())); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"), Inputs())); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))), + "\nexpected op Add but found Placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))), + "\nexpected name add but found placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(Out(NodeWith())))), + 
"\nexpected 1 inputs but node has 0"); +} + +TEST(NodeMatchers, CheckAgainstBinary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b); + + EXPECT_THAT(add.node(), + NodeWith(Op("Add"), Name("add"), + Inputs(Out(NodeWith(Name("placeholder_a"))), + Out(NodeWith(Name("placeholder_b")))))); + + EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())), + "\nexpected 0 inputs but node has 2"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(Out(NodeWith(Name("blah"))), _))), + "\ninput 0 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_a"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(_, Out(NodeWith(Name("blah")))))), + "\ninput 1 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_b"); +} + +TEST(NodeMatchers, CheckControlDependence) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output placeholder_c = + ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT); + Output placeholder_d = + ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT); + + root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node()); + root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node()); + + EXPECT_THAT(placeholder_c.node(), + NodeWith(Name("placeholder_c"), + CtrlDeps(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + EXPECT_THAT(placeholder_d.node(), + NodeWith(Name("placeholder_d"), CtrlDeps())); + + EXPECT_EQ( + 
Explain(placeholder_c.node(), NodeWith(CtrlDeps())), + "ctrl_deps, which has 2 elements, does not match expected: is empty"); + EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))), + "ctrl_deps does not match expected: has 1 element and that element " + "is any node"); +} + +TEST(NodeMatchers, ConstVaulue) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + Output const_0d = ops::Const(root.WithOpName("const_0d"), 42); + + Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}}); + + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42))); + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d"))); + + EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))), + "\nexpected op Const but found Placeholder"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue(43))), + "\nmismatch in constant tensor at index 0 expected = 43 actual = 42"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))), + "\nwas looking for tensor with 4 elements, found tensor with 1 elements"); + EXPECT_EQ( + Explain(const_2d.node(), NodeWith(ConstantValue(42))), + "\nwas looking for tensor with 1 elements, found tensor with 4 elements"); +} + +TEST(NodeMatchers, AssignedDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + + Output assigned_add = + ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b); + assigned_add.node()->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:CPU:0"); + + Output unassigned_add = + ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b); + + 
EXPECT_THAT( + assigned_add.node(), + NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0"))); + EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice(""))); + + EXPECT_EQ(Explain(unassigned_add.node(), + NodeWith(AssignedDevice( + "/job:localhost/replica:0/task:0/device:CPU:0"))), + "\nexpected assigned_device " + "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\""); +} + +TEST(NodeMatchers, OutputIndices) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output pred = ops::Placeholder(root.WithOpName("pred"), DT_BOOL); + + Output data = ops::Placeholder(root.WithOpName("data"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), data, pred); + Output add = ops::Add(root.WithOpName("add"), sw.output_true, + ops::Placeholder(root.WithOpName("addend"), DT_FLOAT)); + + EXPECT_THAT(add.node(), NodeWith(Inputs(Out(1, NodeWith(Op("Switch"))), _))); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(Out(0, NodeWith(Op("Switch"))), _))), + "\ninput 0 does not match expected:\nop: Switch, \nexpected output slot " + "to be 0 but found 1"); +} + +TEST(NodeMatchers, Attrs) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output enter = ops::internal::Enter( + root.WithOpName("enter"), + ops::Placeholder(root.WithOpName("data"), DT_FLOAT), "frame_name", + ops::internal::Enter::Attrs{}.IsConstant(true)); + EXPECT_THAT(enter.node(), NodeWith(Attr("is_constant", true))); + EXPECT_EQ(Explain(enter.node(), NodeWith(Attr("is_constant", false))), + "attribute named is_constant does not match value; expected: " + "\"false\", found: \"true\""); + EXPECT_EQ(Explain(enter.node(), NodeWith(Attr("missing_attr", false))), + "did not find attribute named \"missing_attr\" in node"); +} + +} // namespace +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 13804c6a0575b921839f99ef7d142e0871693b5a..f72224545b25bc7100e0b6788e6fbf0a7ca63dad 100644 --- 
a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -4,9 +4,17 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + cc_library( name = "xla_ops", srcs = ["xla_ops.cc"], deps = ["//tensorflow/core:framework"], alwayslink = 1, ) + +tf_gen_op_wrapper_py( + name = "xla_ops_wrapper_py", + out = "xla_ops.py", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index f2473d98ffd5dae55983e601b8d2d65af6a6d54c..95d12e95fd9a0d1cca513ee74a0651ea69eba89e 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -32,4 +36,63 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); +REGISTER_OP("XlaClusterOutput") + .Input("input: T") + // Note: when replication is supported, this op will have N outputs. 
+ .Output("outputs: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an XLA computation to other " + "consumer graph nodes."); + +REGISTER_OP("_XlaCompile") + .Input("constants: Tconstants") + .Attr("Tconstants: list(type) >= 0") + .Attr("must_compile: bool") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Input("resources: Nresources * resource") + .Attr("Nresources: int >= 0") + .Output("key: string") + .Output("compilation_successful: bool") + .Attr("function: func") + // The compilation cache is stateful. + .SetIsStateful() + .Doc(R"(XLA Compile Op. For use by the XLA JIT only. + +Compiles a TensorFlow function into an XLA LocalExecutable and returns a key +that _XlaRun can use to look up the LocalExecutable and execute it. + +key: A key that can be used to look up the local executable compiled by the + node and associated metadata. + +compilation_successful: If the `must_compile` attr is false the _XlaCompile op + can decide not to compile the clusters based on some profitability + heuristics. In that case `compilation_successful` is false if _XlaCompile + chose not to compile the cluster. If the `must_compile` attr is true then + _XlaCompile always attempts to compile the cluster and + `compilation_successful` is always true. +)"); + +REGISTER_OP("_XlaRun") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Output("results: Tresults") + .Attr("Tresults: list(type) >= 0") + .Input("key: string") + // XLA random-number generation ops are stateful. + // TODO(phawkins): create stateful and non-stateful variants of _XlaRun. + .SetIsStateful() + .Doc(R"(XLA Run Op. For use by the XLA JIT only. + +Executes a TensorFlow function previously compiled into a LocalExecutable by an +_XlaCompile op. 
+)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 3a9a8c4988a4d4cef4f67164f87b1f0aba30224f..5b9610322336acbcede0bef0538043b8ff917c16 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -14,15 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { -Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, - gtl::ArraySlice post_order) { +Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set* result, + absl::Span post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to // avoid the device-host copy we'd otherwise need. 
@@ -30,7 +36,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, MemoryTypeVector input_mtypes, output_mtypes; for (Node* n : post_order) { - absl::optional from_cluster = GetXlaClusterForNode(*n); + absl::optional from_cluster = GetXlaClusterForNode(*n); if (!from_cluster) { continue; } @@ -79,7 +85,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, // Check if `dst` is in a different cluster, unclustered, or about to be // partially declustered (here we rely on the post-order traversal order). // If yes, decluster `n` to avoid the device-to-host memcpy. - absl::optional dst_cluster = + absl::optional dst_cluster = result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst); if (from_cluster != dst_cluster) { CHECK(result->insert(n).second); @@ -91,15 +97,16 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, } Status PartiallyDeclusterNode(Graph* graph, Node* n) { - StringPiece cluster_name = *GetXlaClusterForNode(*n); - gtl::InlinedVector out_edges_to_clone; + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + absl::InlinedVector out_edges_to_clone; for (const Edge* out_edge : n->out_edges()) { if (out_edge->IsControlEdge()) { continue; } Node* dst = out_edge->dst(); - absl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*dst); if (dst_cluster_name != cluster_name) { out_edges_to_clone.push_back(out_edge); } @@ -108,7 +115,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { CHECK(!out_edges_to_clone.empty()) << n->DebugString(); NodeDef ndef = n->def(); - ndef.set_name(strings::StrCat(n->name(), "/declustered")); + ndef.set_name(absl::StrCat(n->name(), "/declustered")); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); @@ -128,30 +135,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { return Status::OK(); } -} // namespace - -Status PartiallyDeclusterPass::Run( 
- const GraphOptimizationPassOptions& options) { - // NB! In this pass we assume the only XLA-auto-clusterable operations that - // may have side effects are resource variable operations so we don't cluster - // those. The pass will have to be updated if this assumption becomes - // invalid. - Graph* graph = options.graph->get(); +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } +// Clones nodes to outside their cluster to avoid device-to-host copies. For +// instance, converts this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been // visited before producers. 
std::vector post_order; GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); + /*edge_filter=*/NotBackedge); - gtl::FlatSet nodes_to_partially_decluster; - TF_RETURN_IF_ERROR(FindNodesToDecluster( - **options.graph, &nodes_to_partially_decluster, post_order)); + absl::flat_hash_set nodes_to_partially_decluster; + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); if (VLOG_IS_ON(3)) { for (Node* n : post_order) { @@ -168,10 +192,142 @@ Status PartiallyDeclusterPass::Run( } nodes_to_partially_decluster.clear(); - TF_RETURN_IF_ERROR(FindNodesToDecluster( - **options.graph, &nodes_to_partially_decluster, post_order)); + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); CHECK(nodes_to_partially_decluster.empty()); return Status::OK(); } + +bool IsIntraClusterEdge(const Edge& edge) { + absl::optional src_cluster_name = + GetXlaClusterForNode(*edge.src()); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*edge.dst()); + return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name; +} + +bool IsMustCompileDevice(const DeviceType& device_type) { + const XlaOpRegistry::DeviceRegistration* registration; + if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { + return registration->requires_compilation; + } + + return false; +} + +Status MustCompileNode(const Node* n, bool* must_compile) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + + if (IsMustCompileDevice(device_type)) { + *must_compile = true; + return Status::OK(); + } + + // We must compile `n` if it does not have a TensorFlow kernel.
+ *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok(); + return Status::OK(); +} + +// Declusters nodes to reduce the number of times we think we need to recompile +// a TensorFlow graph. +// +// Abstractly, if we have a cluster of this form: +// +// x0 = arg0 +// x1 = arg1 +// ... +// shape = f(x0, x1, ...) +// result = Reshape(input=, new_shape=shape) +// +// then pulling `f` out of the cluster may reduce the number of compilations and +// will never increase the number of compilations. +// +// We may reduce the number of compilations if f is many to one. For instance +// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different +// compilations if f is in the cluster but only one compilation if f is outside +// the cluster. +// +// Declustering f will increase the number of compilations only if f is a +// one-to-many "function" i.e. isn't a function at all. RNG is one possible +// example, depending on how we look at it. But we never create clusters where +// such f's would be marked as must-be-constant. +// +// We assume here that the extra repeated (repeated compared to a clustered f +// where it will always be constant folded) host-side computation of f does not +// regress performance in any significant manner. We will have to revisit this +// algorithm with a more complex cost model if this assumption turns out to be +// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) { + std::vector compile_time_const_nodes(graph->num_node_ids()); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); + + std::vector rpo; + GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); + for (Node* n : rpo) { + if (!compile_time_const_nodes[n->id()]) { + continue; + } + + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + bool node_on_cluster_edge = + absl::c_all_of(n->in_edges(), [&](const Edge* e) { + absl::optional incoming_cluster = + GetXlaClusterForNode(*e->src()); + return !incoming_cluster || *incoming_cluster != cluster_name; + }); + + // We don't want to decluster F in a graph like + // + // Input -> OP -> Shape -> F -> Reshape + // + // Doing so will break up the cluster. Even if we were okay with breaking + // up the cluster we will at least have to relabel the two clusters to have + // different cluster names. + // + // We may want to revisit this in the future: we may have cases where OP is + // a small computation that does not benefit from XLA while XLA can optimize + // everything that follows the Reshape. In these cases it may be wise to + // remove Input, OP, Shape and F from the cluster, if F is a many-to-one + // function. + // + // Note that we do do the right thing for graphs like: + // + // Input -> F0 -> F1 -> Reshape + // + // Since we iterate in RPO, we'll first encounter F0, decluster it, then + // encounter F1, decluster it and so on. + if (node_on_cluster_edge) { + bool must_compile_node; + TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node)); + if (!must_compile_node) { + VLOG(3) << "Declustering must-be-constant node " << n->name(); + RemoveFromXlaCluster(n); + } + } + } + + return Status::OK(); +} + +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! 
In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); + TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + + return Status::OK(); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h index 6949b5028ee55e182b27589f9a9711dad7839e86..cfc4ddb5630bec91d6942c983ce1efae3a735c43 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.h +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -20,34 +20,11 @@ limitations under the License. namespace tensorflow { -// Clones nodes from within a cluster to outside the cluster if profitable. +// Clones or moves nodes from within a cluster to outside the cluster if +// profitable. There are two reasons why we do this: // -// Today this only clones to avoid device-to-host copies, but in the future we -// may consider other reasons to clone. For instance, we convert this: -// -// ..... -// | -// v -// A_Clustered ====> C_Unclustered -// | -// v -// B_Clustered -// -// to: -// -// ..... -// | | -// | +-------------+ -// | | -// v v -// A_Clustered A_Unclustered ====> C_Unclustered -// | -// v -// B_Clustered -// -// where the ===> arrow has a hostmem source and destination and would entail a -// device to host copy if the source and destination were not in the same XLA -// cluster. +// - Reducing device-to-host copies. +// - Reducing the number of XLA recompilations. 
class PartiallyDeclusterPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index f61a955c222dd7ce11a177cd54bb8851a5400496..74d5ef57184197ad6e9e5048722e84863756a3f5 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/memory/memory.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" @@ -31,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -58,9 +61,9 @@ class FakeBinaryOp : public OpKernel { void Compute(OpKernelContext* ctx) override { CHECK(false); } }; -class FakeResourceVarUpdateOp : public OpKernel { +class FakeResourceUpdateOp : public OpKernel { public: - explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + explicit FakeResourceUpdateOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { CHECK(false); } @@ -72,17 +75,18 @@ REGISTER_KERNEL_BUILDER(Name("FakeBinary") .HostMemory("host_out"), FakeBinaryOp); -REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") - .Device(DEVICE_CPU) - .HostMemory("something_else"), - FakeResourceVarUpdateOp); +REGISTER_KERNEL_BUILDER( + Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"), + FakeResourceUpdateOp); Status PartiallyDecluster(std::unique_ptr* graph) { FixupSourceAndSinkEdges(graph->get()); // Assign all nodes to the CPU device. 
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } } GraphOptimizationPassOptions opt_options; @@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr* graph) { return pass.Run(opt_options); } -const Node* FindNodeByName(const Graph& graph, const string& name) { - for (const Node* node : graph.nodes()) { +Node* FindNodeByName(const Graph& graph, const string& name) { + for (Node* node : graph.nodes()) { if (node->name() == name) { return node; } @@ -279,5 +283,159 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { "ClusteredProducer0/declustered"); EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); } + +void AddToCluster(absl::Span nodes, + absl::string_view cluster_name) { + for (Node* n : nodes) { + n->AddAttr(kXlaClusterAttr, string(cluster_name)); + } +} + +TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt); +} + +TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) { + 
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT, + ops::Placeholder::Attrs{}); + Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b); + Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul); + + Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()}, + "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + +TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({reshape.node()}, "cluster_0"); + AddToCluster({shape.node()}, "cluster_1"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + 
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + // This is needed to register the XLA_GPU device. + std::vector devices; + TF_ASSERT_OK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); + + // Scope::ToGraph loses the assigned device name since it goes through + // GraphDef/NodeDef which does not have a field for the assigned device name. 
+ Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + n->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:XLA_GPU:0"); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); + + for (Device* d : devices) { + delete d; + } +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output dynamic_slice_operand = + ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32, + ops::Placeholder::Attrs{}); + Output dynamic_slice_begin = ops::Placeholder( + s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{}); + Output dynamic_slice_size = ops::Placeholder( + s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{}); + Output dynamic_slice = + ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand, + dynamic_slice_begin, dynamic_slice_size); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = + ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice); + + AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + Node* n = FindNodeByName(*graph, "dynamic_slice"); + ASSERT_NE(n, nullptr); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 1ba4a5ef7399111e512da8c4966f5899ed828b17..e039d46ec863920eb7deb5bc20525fdab866415c 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -89,8 +90,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" @@ -165,7 +164,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { using ResourceOp = std::pair; string ResourceOpToString(const ResourceOp& resource_op) { - return strings::StrCat( + return absl::StrCat( resource_op.first, ": ", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); } @@ -177,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) { // point. class ResourceOpSet { private: - using Impl = gtl::FlatSet; + using Impl = absl::flat_hash_set; public: ResourceOpSet() = default; @@ -257,11 +256,11 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { std::vector elements_debug_string; std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); - return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); + return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { - return strings::StrCat( + return absl::StrCat( "[", n.name(), ": ", n.type_string(), "(", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); } diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 4f2fabd658330b8ab182e13e02ed0bca41641e46..f85121ca27ad3da918315f93b28e9000dfd65e67 100644 --- 
a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" @@ -52,8 +53,8 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, }; string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); + absl::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); path.resize(path_size); for (int32 node_id : path) { string ascii_art; @@ -64,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } else { ascii_art = "+-- "; } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + absl::StrAppend(&description, ascii_art, node_name(node_id), "\n"); } return description; } @@ -186,7 +187,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } -absl::optional GetXlaClusterForNode(const Node& node) { +absl::optional GetXlaClusterForNode(const Node& node) { const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); if (attr_value == nullptr) { return absl::nullopt; @@ -209,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) { node_def->mutable_attr()->erase(kXlaClusterAttr); } +void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } + Status AdjustCycleDetectionGraphForResourceOps( const Graph* graph, const FunctionLibraryDefinition* flib_def, const std::function& resource_ops_to_ignore, diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index b0439a63ca6476b6b1d63e65308712270381dd9f..ba218f3315d2607c47342fdade0403678faa2362 100644 --- 
a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -47,11 +47,14 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. -absl::optional GetXlaClusterForNode(const Node& node); +absl::optional GetXlaClusterForNode(const Node& node); // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). void RemoveFromXlaCluster(NodeDef* node_def); +// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(Node* node); + // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index ef6b0e67d3c4007f86dc7eef89cacb4cea98fc15..826e98b96620165604594a22b81cd02422605c12 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -40,6 +40,7 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} + XlaCompilationCache::~XlaCompilationCache() { // Ensure any use of our programs have completed by waiting for all stream // executors to complete.
@@ -67,12 +68,12 @@ string XlaCompilationCache::DebugString() { string XlaCompilationCache::SignatureDebugString(const Signature& sig) { string result = sig.name; for (const auto& a : sig.arg_types) { - strings::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + absl::StrAppend(&result, ",", DataTypeString(a.first), + a.second.DebugString()); } for (const auto& v : sig.arg_values) { - strings::StrAppend(&result, "; ", v.DebugString()); + absl::StrAppend(&result, "; ", v.DebugString()); } return result; } @@ -228,38 +229,46 @@ Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + CompileMode compile_mode, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { + // Set the compile threshold to 1 to implement CompileMode::kStrict. + int64 compile_threshold = + compile_mode == CompileMode::kLazy ? 
kDefaultCompilationThreshold : 1; return CompileImpl(options, function, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, false); + compile_options, /*compile_single_op=*/false, + /*compile_threshold=*/compile_threshold, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); return CompileImpl(options, name, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, true); + compile_options, + /*compile_single_op=*/true, /*compile_threshold=*/1, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op) { - CHECK_NE(executable, nullptr); - VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, + int64 compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { + DCHECK_NE(out_executable, nullptr); + VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << 
ctx->num_inputs() @@ -309,9 +318,18 @@ Status XlaCompilationCache::CompileImpl( // TODO(phawkins): this locking will need to be restructured when we implement // cache eviction. mutex_lock entry_lock(entry->mu); + int64 current_request_count = ++entry->request_count; if (!entry->compiled) { - VLOG(1) << "Compilation cache miss for signature: " - << SignatureDebugString(signature); + VLOG(2) << "Compilation cache miss for signature: " + << SignatureDebugString(signature) << " with request count " + << current_request_count << " and compile threshold " + << compile_threshold; + if (current_request_count < compile_threshold) { + *out_compilation_result = nullptr; + *out_executable = nullptr; + return Status::OK(); + } + tensorflow::Env* env = tensorflow::Env::Default(); const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take @@ -357,8 +375,8 @@ Status XlaCompilationCache::CompileImpl( } } TF_RETURN_IF_ERROR(entry->compilation_status); - *compilation_result = &entry->compilation_result; - *executable = entry->executable.get(); + *out_compilation_result = &entry->compilation_result; + *out_executable = entry->executable.get(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 10ad87e38cc4d614e869782329f84351bc3b1f0b..f06a991818db53adb3e5c0cc483c6180128a87e7 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -24,7 +25,6 @@ limitations under the License. 
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -50,6 +50,11 @@ class XlaCompilationCache : public ResourceBase { XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); ~XlaCompilationCache() override; + enum class CompileMode { + kLazy, + kStrict, + }; + // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. @@ -58,6 +63,14 @@ class XlaCompilationCache : public ResourceBase { // `variable_args` is a snapshot of the current values of the // resource variable arguments to `function`; uninitialized variables are // represented by an absent OptionalTensor. + // + // `compile_mode` controls the behavior of the compilation cache on a cache + // miss. If `compile_mode` is `kLazy` then, based on some profitability + // heuristics, the compilation cache may decide not to compile the cluster at + // this time. In this case it returns null into both `out_compilation_result` + // and `out_executable`. If `compile_mode` is `kStrict` then the compilation + // cache always attempts the compilation on a cache miss. + // // The result of compilation is written to `*compilation_result`, which must // be non-null. If `executable` is non-null, also builds an // xla::LocalExecutable and sets `executable` to point to it. 
The resulting @@ -68,9 +81,10 @@ class XlaCompilationCache : public ResourceBase { const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + CompileMode compile_mode, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. @@ -78,9 +92,9 @@ class XlaCompilationCache : public ResourceBase { const XlaCompiler::Options& options, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -89,15 +103,14 @@ class XlaCompilationCache : public ResourceBase { private: // Common implementation of Compile and CompileSingleOp. 
- Status CompileImpl(const XlaCompiler::Options& options, - const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op); + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompileOptions& compile_options, + bool compile_single_op, int64 compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. @@ -140,6 +153,9 @@ class XlaCompilationCache : public ResourceBase { // Have we tried compiling this entry? bool compiled = false; + // The number of times a compilation with this signature has been requested. + int64 request_count = 0; + // Did compilation succeed? Status compilation_status GUARDED_BY(mu); @@ -152,7 +168,7 @@ class XlaCompilationCache : public ResourceBase { }; mutex compile_cache_mu_; - gtl::FlatMap, Signature::Hash> cache_ + absl::flat_hash_map, Signature::Hash> cache_ GUARDED_BY(compile_cache_mu_); struct CompileStats { @@ -165,9 +181,13 @@ class XlaCompilationCache : public ResourceBase { mutex compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. - gtl::FlatMap compile_stats_ + absl::flat_hash_map compile_stats_ GUARDED_BY(compile_stats_mu_); + // The number of times a lazy compilation must be requested for a specific + // signature before we attempt to compile it. 
+ static constexpr int64 kDefaultCompilationThreshold = 2; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 3ba48e8c318f84a4691fb74434bc009fdd0d81bf..79976c85dff200ce993ebb06e7a20a15b71f6085 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -34,6 +34,7 @@ std::map GetVariables(OpKernelContext* ctx) { OptionalTensor& optional = variables[i]; optional.name = handle.name(); if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); optional.present = true; optional.value = *variable->tensor(); @@ -58,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/metadata.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, result, variables); + launch_context.PopulateInputs(ctx, result, variables, + /*missing_ctx_input_prefix=*/0); se::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; @@ -79,7 +81,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, TF_RETURN_IF_ERROR(run_result.status()); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( - ctx, result, run_result.ConsumeValueOrDie())); + ctx, result, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); return Status::OK(); } @@ -177,7 +180,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, compile_options); + compile_options, result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7e159e3171113b0d53f03bb676ac9c21db7fe77a..003c1d8081a3313fd042cdcaea14508ed1048da3 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "Host" (CPU) backend. 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" @@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 70e6d0be0f2cffe98fd77fddac5866789c411a51..0824c4644e3e5d8e1390b99f12de824bfcdfec24 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -148,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { } const DeviceAttributes attrs = Device::BuildDeviceAttributes( - strings::StrCat(name_prefix, "/device:", device_name, ":", - device_ordinal), + absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), - strings::StrCat("device: ", device_name, " device")); + absl::StrCat("device: ", device_name, " device")); device->reset( new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), @@ -185,14 +184,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return device_type_; } -/* static */ Status 
XlaDevice::GetMetadata(OpKernelContext* ctx, - const Metadata** metadata) { +/*static*/ Status XlaDevice::GetMetadataFromDevice( + DeviceBase* device, const XlaDevice::Metadata** metadata) { *metadata = nullptr; - XlaDevice* xla_device = - dynamic_cast(ctx->device()->UnderlyingDevice()); + XlaDevice* xla_device = dynamic_cast(device->UnderlyingDevice()); if (xla_device == nullptr) { return errors::Internal( - "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), + "Cannot get XLA metadata from non-XLA device \"", device->name(), "\". GetMetadata must only be called on an XLA device. Either an " "internal bug has been triggered, or an XLA-specific op has been " "placed on the wrong device."); @@ -201,6 +199,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } +/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, + const Metadata** metadata) { + return GetMetadataFromDevice(ctx->device(), metadata); +} + +/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata) { + return GetMetadataFromDevice(ctx->device(), metadata); +} + XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, @@ -365,10 +373,6 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When Xprof profiling is off (which is the default), constructing the - // activity is simple enough that its overhead is negligible. 
- tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -430,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } +void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { + mutex_lock lock(mu_); + sync_on_completion_ = sync_on_completion; +} + +bool XlaDevice::RequiresSyncOnCompletion() const { + mutex_lock lock(mu_); + return sync_on_completion_; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index dbf35f349f84268ebac0f73a86c9ca0704e90835..0f06b3fc80b7c844dae5643127bdabba8a53b35e 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -88,6 +88,10 @@ class XlaDevice : public LocalDevice { // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata); + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. + static Status GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata); + // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. @@ -147,6 +151,12 @@ class XlaDevice : public LocalDevice { // information for GPU and TPU devices. Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); + // Instructs this XlaDevice to return 'sync_on_completion' for + // RequiresSyncOnCompletion(). 
+ void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + + bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + private: xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) @@ -158,7 +168,10 @@ class XlaDevice : public LocalDevice { xla::StatusOr GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); - mutex mu_; + static Status GetMetadataFromDevice(DeviceBase* device, + const XlaDevice::Metadata** metadata); + + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. @@ -200,6 +213,10 @@ class XlaDevice : public LocalDevice { // Thread pool used for running closures std::unique_ptr thread_pool_; + + // True if the device requires XlaDevice::Sync to be called on completion + // regardless of status. + bool sync_on_completion_ GUARDED_BY(mu_) = false; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ee07c5c9643ef1119b9077326c1cf7c83930e90c..af83c792e5e11d8596c521c6a3aed332a1f42e5b 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -203,7 +203,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { @@ -339,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, diff --git 
a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 2e7445340cbaf788bfd06260f4376596895231c1..df824212948ac96a5df5228cecd9a8c864bbec9a 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -57,7 +57,7 @@ class XlaTransferManager { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, @@ -111,7 +111,7 @@ class XlaDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 13da5d2f948df671df6d0d80687321eaaa923943..14a232b7a8a41f6b4401b2f9de58623af9b1205e 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -65,6 +65,18 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("resources"), \ KERNEL); +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("key") \ + .HostMemory("compilation_successful") \ + .HostMemory("resources"), \ + KERNEL); + +#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); + #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ 
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ @@ -89,9 +101,15 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \ + ResourceHandlesOp); \ REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \ + ReadVariablesOp); \ REGISTER_KERNEL_BUILDER( \ Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ DestroyResourceOp); \ @@ -198,33 +216,33 @@ class XlaAssignVariableOp : public AsyncOpKernel { \ REGISTER_KERNEL_BUILDER( \ Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ - GeneratorDatasetOp); \ + data::GeneratorDatasetOp); \ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ .Device(DEVICE) \ .HostMemory("buffer_size") \ .HostMemory("input_dataset") \ .HostMemory("handle"), \ - PrefetchDatasetOp); \ + data::PrefetchDatasetOp); \ \ REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ - IteratorHandleOp); \ + data::IteratorHandleOp); \ REGISTER_KERNEL_BUILDER( \ Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ - MakeIteratorOp); \ + data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ - AnonymousIteratorHandleOp); \ + data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ - IteratorGetNextOp); \ + data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ - IteratorGetNextSyncOp); \ + data::IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ - IteratorToStringHandleOp); \ + 
data::IteratorToStringHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ - IteratorFromStringHandleOp); \ + data::IteratorFromStringHandleOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 915c5afa79b919f9a9c2a087026a7f85f59e5f11..bc0db558d8d0b7c666efcfac5c4926144b830380 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -41,8 +42,8 @@ static bool IsShapeConsumerOp(const Node& node) { } // Returns true if the op can be decomposed into XLA ops for which -// there are fusable elemental implementations. -bool IsXlaFusable(const NodeDef& node) { +// there are fusible elemental implementations. +static bool IsXlaFusible(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( {// tf2xla/kernels/aggregate_ops.cc @@ -176,9 +177,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); if (device_type.type_string().find("XLA") != string::npos) continue; - // Assume all fusable ops are registered. + // Assume all fusible ops are registered. // TODO(hpucha): Check for registration if possible. 
- if (!IsXlaFusable(node->def())) { + if (!IsXlaFusible(node->def())) { continue; } @@ -326,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc index b77b207908f8612bc8bba011645c3ac98de9de0e..68e19c8a135735a79fcabf121e619157fa22b4d8 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -73,7 +73,7 @@ TEST_F(XlaFusionOptimizerTest, Chains) { EXPECT_TRUE(clusters.find("D") == clusters.cend()); } -TEST_F(XlaFusionOptimizerTest, FusableOps) { +TEST_F(XlaFusionOptimizerTest, FusibleOps) { GraphDef graph; { GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ef4466f0056ea98adc1ae6774105466af0d14293..60979556a3245f4a9984cde889835ce31154fe18 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 45745596749207189c60ee1e3dcf19b6ecb7eb5b..8a80639b6391ba9b73fe3143df8f6e44505cec2c 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -15,7 +15,7 @@ limitations under the License. // Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -25,8 +25,9 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kExecAllTypes = { + {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: @@ -72,6 +73,10 @@ static bool OpFilter(KernelDef* kdef) { return true; } REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, kExecAllTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, + kExecAllTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 2ffce9298d99e1e136e15e9a4b0e3f5b26121bd5..4f6fc4e068e3ba125ddbca264c1affa1f09f5896 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -42,13 +42,14 @@ using xla::ShapedBuffer; } // anonymous namespace std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables) { + OpKernelContext* ctx, absl::Span variables) { std::map snapshot; for (int i : variables) { Var* variable = nullptr; ResourceHandle handle = HandleFromInput(ctx, i); OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock 
lock(*variable->mu()); tensor.name = handle.name(); tensor.present = true; @@ -133,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext( void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map& variables) { + const std::map& variables, + int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. @@ -145,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs( const Tensor* t; for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { int arg_num = kernel->input_mapping[i]; + DCHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = kernel->xla_input_shapes[i]; if (variables.count(arg_num)) { t = &(variables.at(arg_num).value); CHECK(t); } else { - t = &(ctx->input(arg_num)); + t = &(ctx->input(arg_num - missing_ctx_input_prefix)); } if (use_multiple_streams_) { @@ -187,7 +190,7 @@ void XlaComputationLaunchContext::PopulateInputs( Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - ScopedShapedBuffer output) { + ScopedShapedBuffer output, int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; @@ -271,31 +274,38 @@ Status XlaComputationLaunchContext::PopulateOutputs( } } else { const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - se::DeviceMemoryBase buffer = output.buffer({output_num}); - if (allocate_xla_tensors_) { - Tensor* output_tensor; - TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); - if (use_multiple_streams_) { - xla_tensor->SetDefinedOn(stream, definition_event); + const DataType& type = kernel->outputs[i].type; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " + << DataTypeString(type); + if (type == DT_RESOURCE) { + TF_RET_CHECK(kernel->outputs[i].input_index >= 0) + << "Invalid input for outputs " << i; + ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); + } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); } } else { - // xla_tensor wasn't valid, which must mean this is a zero-element - // tensor. 
- CHECK_EQ(output_tensor->TotalBytes(), 0); + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); } - } else { - Tensor output_tensor = XlaTensorBuffer::MakeTensor( - ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); - ctx->set_output(i, output_tensor); + ++output_num; } - ++output_num; } if (VLOG_IS_ON(3)) { @@ -308,7 +318,8 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + int actual_input_index = write.input_index - missing_ctx_input_prefix; + if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { return errors::Internal("Invalid input index for variable write."); } @@ -318,7 +329,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. TF_RETURN_IF_ERROR(LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), &variable, + ctx, HandleFromInput(ctx, actual_input_index), &variable, [&write](Var** ptr) { *ptr = new Var(write.type); return Status::OK(); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 7ac275fab833400b90ced0180192845c9be30534..326d70a027564343408df356833c97e131495da0 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { class XlaAllocator; @@ -43,7 +44,7 @@ class XlaAllocator; // resource variable is not initialized, the corresponding OptionalTensor // will have its `present` field set to false. std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables); + OpKernelContext* ctx, absl::Span variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -88,14 +89,24 @@ class XlaComputationLaunchContext { // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. All elements in kernel's + // input_mapping must be greater than or equal to `missing_ctx_input_prefix` + // (in other words, no inputs actually required by the kernel can be missing). void PopulateInputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map& variables); + const std::map& variables, + int missing_ctx_input_prefix); // Given the XLA output in `output`, populate all outputs of `ctx`. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. Status PopulateOutputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + xla::ScopedShapedBuffer output, + int missing_ctx_input_prefix); // Return the argument list. Only valid after PopulateInputs() has been // called. 
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 4c9bb2e27b0ca3c83848be7fdf189fdbad89cee5..d95da63405889dfd0c279b17789a2195072c7277 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -122,7 +122,7 @@ class XlaTensor { std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. - gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); mutex mu_; }; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 94e08b6efe99fce73243c4e22bdd7565bdea6ef7..a8a9f39e10620499237c77883925a0223298a2b4 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "medium", + size = "large", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -251,6 +251,7 @@ tf_xla_py_test( tf_xla_py_test( name = "matrix_triangular_solve_op_test", size = "small", + timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], tags = ["optonly"], deps = [ @@ -276,9 +277,10 @@ tf_xla_py_test( ], ) +# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "medium", + size = "large", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -572,6 +574,7 @@ tf_xla_py_test( tf_xla_py_test( name = "matrix_band_part_test", size = "medium", + timeout = "long", srcs = ["matrix_band_part_test.py"], tags = ["optonly"], deps = [ @@ -579,6 +582,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -890,6 +894,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "tensor_list_ops_test", + size = "small", + srcs = ["tensor_list_ops_test.py"], + # 
TensorList ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:list_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", @@ -974,7 +994,7 @@ tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], - tags = ["noasan"], # times out, http://b/78599043 + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1024,6 +1044,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "permute_test", + size = "small", + srcs = ["permute_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:nn_ops", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", @@ -1056,6 +1089,7 @@ cuda_py_test( size = "medium", srcs = ["jit_test.py"], additional_deps = [ + ":test_utils", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1074,6 +1108,7 @@ cuda_py_test( size = "small", srcs = ["dense_layer_test.py"], additional_deps = [ + ":test_utils", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1101,6 +1136,8 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -1193,8 +1230,21 @@ tf_xla_py_test( ) tf_xla_py_test( - name = "xla_ops_test", + name = "quantized_ops_test", size = "small", + srcs = ["quantized_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + 
"//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "xla_ops_test", + size = "medium", srcs = ["xla_ops_test.py"], disabled_backends = ["cpu_ondemand"], deps = [ diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 0d2e4d029636577adc74784d9a8b3494b94dc67d..058576b3d4b695209952158769162bb24e7ccfce 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -53,9 +54,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -95,9 +96,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -137,9 +138,9 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. 
- if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 4155342787fbbdeaf5c5958c44d007b1ea0660ed..68f52e796c283997b71abcdb9c3bd6aa19cb06fc 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase): def testArgMinMax(self): # Complex numbers do not support argmin/argmax. - minmax_types = set(self.numeric_types) - set(self.complex_types) + minmax_types = self.all_types & {np.int32, np.int64} for dtype in minmax_types: # output_type is a numpy data type that is used to specify the desired # output type of the op as well as to convert the Python number to the # array scalar of the type. 
- for output_type in self.int_types: + for output_type in minmax_types: self._assertOpOutputMatchesExpected( math_ops.argmax, axis=0, diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index ed4940f204a032527a9926a71f5d99286ef18029..1b39d53dc0908e1fa05f766ca1e601731b26846d 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase): equality_test=self.ListsAreClose) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testBinary( gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), @@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - if dtype not in self.complex_types: # min/max not supported for complex + # min/max not supported for complex + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.maximum, np.array([1, 2], dtype=dtype), @@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([[70], [14]], dtype=dtype)) # Complex support for squared_difference is incidental, see b/68205550 - if dtype not in self.complex_types: + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.squared_difference, np.array([1, 2], dtype=dtype), @@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1) + divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24) + np_result = np.true_divide(nums, divs) + np_result[:, divs[0] == 0] = 0 + self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result) + if dtype not in self.complex_types: # floordiv unsupported for complex. 
self._testBinary( gen_math_ops.floor_div, @@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testDivision(dtype) def testFloatDivision(self): @@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, 1, -1, 0], dtype=dtype)) def testIntRemainder(self): - for dtype in self.int_types: + for dtype in self.signed_int_types - {np.int8}: self._testRemainder(dtype) def testFloatRemainder(self): @@ -1010,7 +1018,38 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) - def testMirrorPad(self): + def testSymmetricMirrorPad(self): + mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") + for dtype in self.numeric_types: + self._testBinary( + mirror_pad, + np.array( + [ + [1, 2, 3], # + [4, 5, 6], # + ], + dtype=dtype), + np.array([[ + 2, + 2, + ], [3, 3]], dtype=np.int32), + expected=np.array( + [ + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + ], + dtype=dtype)) + self._testBinary( + mirror_pad, + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[0, 0], [0, 0]], dtype=np.int32), + expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + + def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: self._testBinary( @@ -1406,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 0], dtype=np.int32), expected=np.zeros([4, 0], dtype=dtype)) + x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) + self._testBinary( + array_ops.broadcast_to, + x, + np.array((3, 7, 8, 9), dtype=np.int32), + expected=np.tile(x, (1, 7, 8, 9))) + if __name__ == "__main__": googletest.main() diff --git 
a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 7b114d4f85d3a5cadc6af25b55c5a21f90d2a768..1d3979b21bfd915a641fabe1ef40301b3e5a17b4 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -2,90 +2,103 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) def all_backends(): - b = ["cpu"] + plugins.keys() - if cuda_is_configured(): - return b + ["gpu"] - else: - return b + b = ["cpu"] + plugins.keys() + if cuda_is_configured(): + return b + ["gpu"] + else: + return b -def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, - disabled_backends=None, **kwargs): - """Generates py_test targets, one per XLA backend. +def tf_xla_py_test( + name, + srcs = [], + deps = [], + tags = [], + data = [], + main = None, + disabled_backends = None, + **kwargs): + """Generates py_test targets, one per XLA backend. - This rule generates py_test() targets named name_backend, for each backend - in all_backends(). The rule also generates a test suite with named `name` that - tests all backends for the test. + This rule generates py_test() targets named name_backend, for each backend + in all_backends(). The rule also generates a test suite with named `name` that + tests all backends for the test. - For example, the following rule generates test cases foo_test_cpu, - foo_test_gpu, and a test suite name foo_test that tests both. - tf_xla_py_test( - name="foo_test", - srcs="foo_test.py", - deps=[...], - ) + For example, the following rule generates test cases foo_test_cpu, + foo_test_gpu, and a test suite name foo_test that tests both. + tf_xla_py_test( + name="foo_test", + srcs="foo_test.py", + deps=[...], + ) - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. 
- tags: Tags to apply to the generated targets. - data: Data dependencies of the target. - main: Same as py_test's main attribute. - disabled_backends: A list of backends that should not be tested. Supported - values include "cpu" and "gpu". If not specified, defaults to None. - **kwargs: keyword arguments passed onto the generated py_test() rules. - """ - if disabled_backends == None: - disabled_backends = [] + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + tags: Tags to apply to the generated targets. + data: Data dependencies of the target. + main: Same as py_test's main attribute. + disabled_backends: A list of backends that should not be tested. Supported + values include "cpu" and "gpu". If not specified, defaults to None. + **kwargs: keyword arguments passed onto the generated py_test() rules. + """ + if disabled_backends == None: + disabled_backends = [] - enabled_backends = [b for b in all_backends() if b not in disabled_backends] - test_names = [] - for backend in enabled_backends: - test_name = "{}_{}".format(name, backend) - backend_tags = ["tf_xla_{}".format(backend)] - backend_args = [] - backend_deps = [] - backend_data = [] - if backend == "cpu": - backend_args += [ - "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" - ] - elif backend == "gpu": - backend_args += [ - "--test_device=XLA_GPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" - ] - backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_args += ["--test_device=" + plugins[backend]["device"], - "--types=" + plugins[backend]["types"]] - backend_tags += plugins[backend]["tags"] - backend_args += plugins[backend]["args"] - backend_deps += plugins[backend]["deps"] - backend_data += plugins[backend]["data"] - else: - fail("Unknown backend {}".format(backend)) + enabled_backends = [b for b in all_backends() if b not in 
disabled_backends] + test_names = [] + for backend in enabled_backends: + test_name = "{}_{}".format(name, backend) + backend_tags = ["tf_xla_{}".format(backend)] + backend_args = [] + backend_deps = [] + backend_data = [] + if backend == "cpu": + backend_args += [ + "--test_device=XLA_CPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + ] + elif backend == "gpu": + backend_args += [ + "--test_device=XLA_GPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", + ] + backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_args += [ + "--test_device=" + plugins[backend]["device"], + "--types=" + plugins[backend]["types"], + ] + backend_tags += plugins[backend]["tags"] + backend_args += plugins[backend]["args"] + backend_deps += plugins[backend]["deps"] + backend_data += plugins[backend]["data"] + else: + fail("Unknown backend {}".format(backend)) - native.py_test( - name=test_name, - srcs=srcs, - srcs_version="PY2AND3", - args=backend_args, - main="{}.py".format(name) if main == None else main, - data=data + backend_data, - deps=deps + backend_deps, - tags=tags + backend_tags, - **kwargs - ) - test_names.append(test_name) - native.test_suite(name=name, tests=test_names) + native.py_test( + name = test_name, + srcs = srcs, + srcs_version = "PY2AND3", + args = backend_args, + main = "{}.py".format(name) if main == None else main, + data = data + backend_data, + deps = deps + backend_deps, + tags = tags + backend_tags, + **kwargs + ) + test_names.append(test_name) + native.test_suite(name = name, tests = test_names) -def generate_backend_suites(backends=[]): - """Generates per-backend test_suites that run all tests for a backend.""" - if not backends: - backends = all_backends() - for backend in backends: - native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) +def 
generate_backend_suites(backends = []): + """Generates per-backend test_suites that run all tests for a backend.""" + if not backends: + backends = all_backends() + for backend in backends: + native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend]) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 37e5318bb54c5d8ecdedc7bb346e89765f2adf35..2d225ad226cac368042b95eae8fc29e6fd8e82e0 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase): ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"): array_ops.concat([scalar, scalar, scalar], dim) + # The purpose of this is to ensure that XLA on GPU will not run out of memory + # with too many arguments. + def testConcatLargeNumberOfTensors(self): + with self.cached_session(): + with self.test_scope(): + for concat_dim in range(2): + params = {} + p = [] + shape = np.array([7, 13]) + num_tensors = 1001 + for i in np.arange(num_tensors): + input_shape = shape + placeholder = array_ops.placeholder( + dtypes.float32, shape=input_shape) + p.append(placeholder) + params[placeholder] = np.random.rand(*input_shape).astype( + np.float32) + + concat_inputs = p + c = array_ops.concat(concat_inputs, concat_dim) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. 
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)] + index[concat_dim] = slice( + cur_offset, cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + self.assertAllEqual(result[index], params[p[i]]) + class ConcatOffsetTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 04f3b3ef4905984b0432a536c3b1c275738ede17..d1b90f098d7d6574999ba0af44b285f5ad5e4f8d 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.python.layers import layers @@ -30,7 +31,6 @@ from tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope - def GetRunMetadataLabels(run_metadata): """Returns all labels in run_metadata.""" labels = [] @@ -45,45 +45,51 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def XlaLaunchOpCount(labels): - """Count how many XlaLaunch labels are present.""" - return sum("XlaLaunch(" in x for x in labels) +class DenseLayerTest(test.TestCase): + def countXlaOps(self, labels): + """Count how many XlaCompile/XlaRun labels are present.""" + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + return xla_run_count -class DenseLayerTest(test.TestCase): def testDenseLayerAutoJit(self): """Tests dense layer compilation in auto-jit mode. - Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + Dense layer should be compiled into a single XlaCompile/XlaRun op pair in + auto-jit mode. 
""" - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) config = config_pb2.ConfigProto() config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) - with self.test_session(config=config) as sess: + with self.session(config=config) as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) sess.run(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertEqual(1, self.countXlaOps(labels)) + self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. Dense layer with static shape input tensor should be compiled into a single - XlaLaunch op by XLA. + XlaCompile/XlaRun op pair by XLA. """ with self.cached_session() as sess: @@ -93,14 +99,15 @@ class DenseLayerTest(test.TestCase): sess.run(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertEqual(1, self.countXlaOps(labels)) # No need to check whether ListDiff is compiled or not because ListDiff op # is not used when input tensor shape is fully defined. 
@@ -110,7 +117,8 @@ class DenseLayerTest(test.TestCase): Dense layer uses shape op to get shape of input tensor if its shape is not fully defined. XLA does not cluster shape op with other operators. But in experimental_jit_scope, XLA is forced to compile shape op into its own - cluster, causing dense layer to be split into TWO XlaLaunch ops. + cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op + pairs. """ with self.cached_session() as sess: @@ -120,16 +128,19 @@ class DenseLayerTest(test.TestCase): sess.run(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(2, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertEqual(2, self.countXlaOps(labels)) + self.assertFalse(InLabels(labels, "MatMult")) if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index e32f3d4b7f5715a9dbe88ea241a643729dfb2a48..63cee550fde9d9d4314b1541fba191df776a4da2 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase): var = f(v) self.assertEqual(2.0, var.numpy()) + def testReturnResourceHandle(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]]) + + def f(v): + return v.handle + + f = function.defun(f) + handle = f(v) + self.assertAllEqual(v.numpy(), + resource_variable_ops.read_variable_op( + handle, dtypes.float32).numpy()) + + def testReturnMultipleResourceHandles(self): + with 
self.test_scope(): + v1 = resource_variable_ops.ResourceVariable(1.25) + v2 = resource_variable_ops.ResourceVariable(2.0) + + def f(v): + return v.handle, 3.0 * v, v2.handle, v + v2 + + f = function.defun(f) + v1_handle, v1_times_3, v2_handle, variable_sum = f(v1) + self.assertAllEqual(v1.numpy(), + resource_variable_ops.read_variable_op( + v1_handle, dtypes.float32).numpy()) + self.assertEqual(3.75, v1_times_3.numpy()) + self.assertAllEqual(v2.numpy(), + resource_variable_ops.read_variable_op( + v2_handle, dtypes.float32).numpy()) + self.assertEqual(3.25, variable_sum.numpy()) + def testAllArgumentKinds(self): """Test a complex function that takes different argument kinds. @@ -457,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase): y = two_x_plus_1(x) self.assertAllEqual([5, 7, 9], y.numpy()) + def testNestedDefunWithVariable(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + y = f(x) + + self.assertEqual(75, y.numpy()) + + def testNestedDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + def testNestedDefunInGradientTapeDifferentVars(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + v1 = resource_variable_ops.ResourceVariable(3.0) + + @function.defun + def g(x): + x = v1 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape(persistent=True) as tape: + y = f(x) + dy_v0 = tape.gradient(y, v0) + dy_v1 = 
tape.gradient(y, v1) + + self.assertEqual(45, y.numpy()) + self.assertEqual(9, dy_v0.numpy()) + self.assertEqual(15, dy_v1.numpy()) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 7ca50b02d9bf3203cbd460c8de13a16defd974a3..f1b87a5ffb73bed62a80abaa152d335f64d970c5 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -29,7 +29,6 @@ from tensorflow.python.training import adagrad from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent - class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): @@ -196,7 +195,11 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4) + np.array([-7.66718769, -10.91273689]), + var0.eval(), + rtol=1e-4, + bfloat16_rtol=1e-1, + bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) @@ -259,9 +262,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + + def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): + """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + grads0 = 
constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.1, 0.2], dtype=dtype) + + opt0 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0, + l2_shrinkage_regularization_strength=0.1) + opt1 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update0 = opt0.apply_gradients([(grads0, var0)]) + update1 = opt1.apply_gradients([(grads1, var1)]) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + update0.run() + update1.run() + + # var0 is experiencing L2 shrinkage so it should be smaller than var1 + # in magnitude. + self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + accum0 = list(opt0._slots["accum"].values())[0].eval() + accum1 = list(opt1._slots["accum"].values())[0].eval() + # L2 shrinkage should not change how we update grad accumulator. + self.assertAllCloseAccordingToType(accum0, accum1) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. 
Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 8c018cccb83a05babb0b7f73b80b4f9de7267c98..374942a0b339b816944ea5529e4f84134b60017b 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn from tensorflow.python.platform import test +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), +) + class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): @@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] @@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearning(self, data_format): self._testLearning(False, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearningWithGradientChecker(self, data_format): self._testLearning(True, data_format) - 
@parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. @@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 089d95daab7e502b4ba13796fadc2ba3f209759b..a38e1edafe883f6d3b64e1d7f94e394cccafa2e9 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase): indices_tf = constant_op.constant(indices) gather_t = array_ops.gather(params, indices_tf) gather_val = session.run(gather_t, feed_dict={params: params_np}) - np_val = params_np[indices] + np_val = constant_op.constant(params_np[indices]) self.assertAllEqual(np_val, gather_val) def testScalar2D(self): @@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant(2) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, 2, axis=axis) + expected = constant_op.constant( + np.take(params_np, 2, axis=axis), dtype) self.assertAllEqual(expected, gather_val) def 
testSimpleTwoD32(self): @@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant([0, 1, 0, 2]) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32_Int64Indices(self): @@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase): params: params_np, indices: indices_np }) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testHigherRank(self): @@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase): tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) gather_value = sess.run(gather, feed_dict={tf_params: params}) - gather_np = np.take(params, indices, axis=axis) + gather_np = constant_op.constant( + np.take(params, indices, axis=axis), dtype) self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 6fe5a66e0e6717ec738dded9196eef6ba1e2114d..d67b16f8e9e7320d5717b0203be340a2356e53d0 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -26,7 +26,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test -from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -605,137 +604,205 @@ class ResizeBilinearTest(xla_test.XLATestCase): class 
NonMaxSuppressionTest(xla_test.XLATestCase): def testNMS128From1024(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - - with compat.forward_compatibility_horizon(2018, 8, 8): - num_boxes = 1024 - boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") - scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") - - max_output_size = 128 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.0, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - score_threshold: score_threshold_np, - iou_threshold: iou_threshold_np - } - (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + 
score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) def testNMS3From6Boxes(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - - with compat.forward_compatibility_horizon(2018, 8, 8): - # Three boxes are selected based on IOU. - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - - max_output_size = 3 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.0, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - score_threshold: score_threshold_np, - 
iou_threshold: iou_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 3) - self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. 
- if self.device in ["XLA_CPU", "XLA_GPU"]: - return - - with compat.forward_compatibility_horizon(2018, 8, 8): - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - max_output_size = 3 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.4, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - iou_threshold: iou_threshold_np, - score_threshold: score_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 2) - self.assertAllClose(indices_tf[:num_valid], [3, 0]) - + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, 
shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) + + def testNMS3Then1WithScoreMaxThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + # One is filtered out by max_output_size. 
+ + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) + + def testSelectFromContinuousOverLap(self): + # Tests that a suppressed box does not itself suppress other boxes. 
+ + boxes_data = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4], + [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 3]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.1, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [0, 2, 4]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6e0db54b7a74b284dc7d18bcbb07c178c664c1e5..8778b54dfaf35003c83cf2ab03e9e218c60c98ed 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -36,8 +37,8 
@@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -jit_scope = jit.experimental_jit_scope +jit_scope = jit.experimental_jit_scope # Disable rewrites to make sure we don't end up having to update this test # whenever we implement new ones. @@ -77,11 +78,11 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" +def MetadataHasXlaRunOp(run_metadata): + """Returns true if there are XlaRun kernels in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "_XlaRun") class JitLaunchTest(test.TestCase): @@ -90,9 +91,10 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node - # actually ran. However, it is sometimes possible for XlaLaunch ops to be - # constant-folded away, so the check is optional. + # + # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun + # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun + # ops to be constant-folded away, so the check is optional. 
def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] @@ -107,15 +109,14 @@ class JitLaunchTest(test.TestCase): direct_op = fn(*placeholders) run_metadata = config_pb2.RunMetadata() - compiled = sess.run(compiled_op, - feeds, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + compiled = test_utils.RunWithWarmup( + sess, compiled_op, feeds, + config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE), + run_metadata) print("Compiled Result {}".format(compiled)) if require_kernel_launch: - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) direct = sess.run(direct_op, feeds) print("Direct Result {}".format(direct)) @@ -136,7 +137,7 @@ class JitLaunchTest(test.TestCase): a = constant_op.constant(100) # pylint: disable=unused-variable call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return - sess.run(call, {}) + test_utils.RunWithWarmup(sess, call, {}) def testAliasing(self): """Regression test for compiled functions that return an aliased buffer. @@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase): y = math_ops.add(x, x) return y, y - # Exercises compling a function (say, Foo) which calls another - # function (say, Bar) which is not inlined. When the compiler compiles - # Foo, it needs to symbolic execute Bar correctly regardless whether - # Bar is inlined or not. + # Exercises compiling a function (say, Foo) which calls another function + # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs + # to symbolically execute Bar correctly regardless of whether Bar is inlined + # or not. # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True. 
@@ -249,17 +250,21 @@ class JitLaunchTest(test.TestCase): dx = np.random.random_sample((batch_size, image_size)).astype(np.float32) with session_lib.Session() as sess: run_metadata = config_pb2.RunMetadata() - output = sess.run(y, {x: dx, - w: dw, - b: db}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + output = test_utils.RunWithWarmup( + sess, + y, { + x: dx, + w: dw, + b: db + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) # TODO(phawkins): really we would like to test that there were exactly # two kernel launches. However, we have no reliable way to determine # that. - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) expected = np.square(np.dot(dx, dw) + db) self.assertAllClose(expected, output, rtol=1e-1) @@ -271,7 +276,7 @@ class XlaCompilationTest(test.TestCase): def testReshape(self): """Tests an operator with compile-time constant and non-constant inputs.""" - with self.test_session(config=NoRewriteSessionConfig()) as sess: + with self.session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -283,19 +288,22 @@ class XlaCompilationTest(test.TestCase): # statically known as part of the JIT compilation's input graph. 
z = array_ops.reshape(x, y) run_metadata = config_pb2.RunMetadata() - out = sess.run(z, - {x: np.array([1, 2, 3, 4, 5, 6], np.float32), - y: [-1, 3]}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + out = test_utils.RunWithWarmup( + sess, + z, { + x: np.array([1, 2, 3, 4, 5, 6], np.float32), + y: [-1, 3] + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) def testIgnoredArguments(self): """Tests that JIT computations can ignore formal parameters.""" - with self.test_session(config=NoRewriteSessionConfig()) as sess: + with self.session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.int32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -308,18 +316,22 @@ class XlaCompilationTest(test.TestCase): t = math_ops.add(z, z) run_metadata = config_pb2.RunMetadata() - out = sess.run(t, {x: np.int32(7), - y: np.int32(404)}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + out = test_utils.RunWithWarmup( + sess, + t, { + x: np.int32(7), + y: np.int32(404) + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(28, out) def testLoops(self): """Tests that compilation accepts computations containing loops.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): c = lambda i, _: math_ops.less(i, 5) @@ -331,13 +343,13 @@ class XlaCompilationTest(test.TestCase): 
run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(result, np.float32(95), rtol=1e-1) def testCond(self): """Tests that compilation handles switch operators.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.float32) c = array_ops.placeholder(dtypes.bool) @@ -350,13 +362,17 @@ class XlaCompilationTest(test.TestCase): # deadlock. run_metadata = config_pb2.RunMetadata() - result = session.run(t, {x: np.float32(2), - y: np.float32(4), - c: True}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + result = test_utils.RunWithWarmup( + session, + t, { + x: np.float32(2), + y: np.float32(4), + c: True + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(result, np.float32(6), rtol=1e-1) def testNestedFunction(self): @@ -378,7 +394,7 @@ class XlaCompilationTest(test.TestCase): inp = array_ops.placeholder(dtypes.float32) out = Entry(inp) - with self.test_session( + with self.session( config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess: run_metadata = config_pb2.RunMetadata() val = sess.run(out, @@ -391,7 +407,7 @@ class XlaCompilationTest(test.TestCase): def testLoopDeadlock(self): """Regression test for bug that caused deadlocks in graphs with loops.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): y = x + 1.0 @@ -424,11 +440,13 @@ class 
XlaCompilationTest(test.TestCase): cfg.graph_options.optimizer_options.do_function_inlining = True with session_lib.Session(graph=g, config=cfg) as sess: run_metadata = config_pb2.RunMetadata() - dx_val = sess.run(dx, - feed_dict={x: 100.}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + dx_val = test_utils.RunWithWarmup( + sess, + dx, + feed_dict={x: 100.}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) self.assertAllClose(dx_val, 0.01) return RunMetadataLabels(run_metadata) @@ -441,14 +459,16 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaCompile")) + self.assertFalse(InLabels(labels, "XlaRun")) - # Compile the backprop. One XlaLaunch. + # Compile the backprop. One XlaCompile/XlaRun pair. 
labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaCompile")) + self.assertTrue(InLabels(labels, "XlaRun")) class ElementWiseFusionTest(test.TestCase): @@ -472,7 +492,8 @@ class ElementWiseFusionTest(test.TestCase): a7 = a6 + a2 run_metadata = config_pb2.RunMetadata() - output = sess.run( + output = test_utils.RunWithWarmup( + sess, a7, { a1: arg0, a2: arg1 @@ -482,15 +503,19 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("XlaLaunch(" in x for x in labels) - return output, count + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + + return output, xla_run_count def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) @@ -502,5 +527,60 @@ class ElementWiseFusionTest(test.TestCase): self.assertAllClose(tf_op, tfef_op, rtol=1e-1) +class LazyCompilationTest(test.TestCase): + + def testLazyCompilation(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + run_metadata_before_warmup = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10.]}, + 
run_metadata=run_metadata_before_warmup, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile")) + self.assertFalse( + InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun")) + + # We compile when we see the same shape a second time. + + run_metadata_after_warmup = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10.]}, + run_metadata=run_metadata_after_warmup, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile")) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun")) + + run_metadata_for_new_shape = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10., 12.]}, + run_metadata=run_metadata_for_new_shape, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile")) + self.assertFalse( + InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) + + if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py index 43c469d0320645cdad6ddc67f3e8cb1374b8e9e5..73b3638e801e7389e83953f6662bcfc78ad86203 100644 --- a/tensorflow/compiler/tests/lstm.py +++ b/tensorflow/compiler/tests/lstm.py @@ -117,7 +117,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq): def RandomVar(shape, name=None): """Returns a variable of the given shape initialized to random values.""" - return variables.Variable( + return variables.VariableV1( random_ops.random_uniform(shape), dtype=dtypes.float32, name=name) diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py 
b/tensorflow/compiler/tests/matrix_band_part_test.py index 9222db4b7ebf020c8cee1c0af81e05129fb33c4d..c61965b97fc142ce452cf28def8c937f692d2f84 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(xla_test.XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): - def _testMatrixBandPart(self, dtype, shape): - with self.cached_session(): - batch_shape = shape[:-2] - mat = np.ones(shape).astype(dtype) - batch_mat = np.tile(mat, batch_shape + [1, 1]) - for lower in -1, 0, 1, shape[-2] - 1: - for upper in -1, 0, 1, shape[-1] - 1: - band_np = mat - if lower >= 0: - band_np = np.triu(band_np, -lower) - if upper >= 0: - band_np = np.tril(band_np, upper) - if batch_shape: - band_np = np.tile(band_np, batch_shape + [1, 1]) - - placeholder = array_ops.placeholder(dtype) - with self.test_scope(): - band = array_ops.matrix_band_part( - placeholder, - constant_op.constant(lower, dtype=dtypes.int32), - constant_op.constant(upper, dtype=dtypes.int32)) - feed_dict = {placeholder: batch_mat} - self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) - - def testMatrixBandPart(self): + @parameterized.parameters( + { + 'batch_shape': [], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [], + 
'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 7 + }, + ) + def testMatrixBandPart(self, batch_shape, rows, cols): for dtype in self.float_types: - for batch_shape in [[], [2,], [1, 3, 2]]: - for rows in 1, 2, 7: - for cols in 1, 2, 7: - self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + with self.cached_session(): + mat = np.ones(batch_shape + [rows, cols]).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, rows - 1: + for upper in -1, 0, 1, cols - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, constant_op.constant(lower, dtype=dtypes.int32), + 
constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index f985c5d2d96e06fc0117f3935d61b19c9e8562b1..38cb2f83efc48ffcdf5403a23e666963b2ea4da1 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase): output.run() def testConstants(self): - constants = [ - np.float32(42), - np.array([], dtype=np.float32), - np.array([1, 2], dtype=np.float32), - np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), - np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], - dtype=np.float32), - np.array([[[]], [[]]], dtype=np.float32), - np.array([[[[1]]]], dtype=np.float32), - ] - for c in constants: - self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + for dtype in self.numeric_types: + constants = [ + dtype(42), + np.array([], dtype=dtype), + np.array([1, 2], dtype=dtype), + np.array([7, 7, 7, 7, 7], dtype=dtype), + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + + def testComplexConstants(self): + for dtype in self.complex_types: + constants = [ + dtype(42 + 3j), + np.array([], dtype=dtype), + np.ones([50], dtype=dtype) * (3 + 4j), + np.array([1j, 2 + 1j], dtype=dtype), + np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4 + 6j], [5, 6]], + [[10 + 7j, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1 + 3j]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: 
constant_op.constant(c), expected=c) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb9274df4f579fbc6076bf55c9307e4d1cb7768 --- /dev/null +++ b/tensorflow/compiler/tests/permute_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the DataFormatVecPermute operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaPermuteOpTest(xla_test.XLATestCase): + + def _runPermuteAndCompare(self, x, src_format, dst_format, expected): + with self.cached_session() as session: + with self.test_scope(): + placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape) + param = {placeholder: x} + output = nn_ops.data_format_vec_permute( + placeholder, src_format=src_format, dst_format=dst_format) + result = session.run(output, param) + self.assertAllEqual(result, expected) + + def testNHWCToNCHW(self): + x = np.array([7, 4, 9, 
3], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "NCHW", + [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "HWNC", + [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "HWNC", "NHWC", + [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NCHW", "NHWC", + [[7, 4], [4, 5], [5, 1], [9, 3]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 3a268978bfd72d08a7d3a7cc61a116dac543cda5..236b1b881dcaffc1a5b0c6395f0605c1d7ef0269 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -101,8 +101,8 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): @parameterized.parameters(*PARAMS) def testQR(self, rows, cols, dtype): - # TODO(b/111317468): implement full_matrices=False, test other types. - for full_matrices in [True]: + # TODO(b/111317468): Test other types. + for full_matrices in [True, False]: # Only tests the (3, 2) case for small numbers of rows/columns. 
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): self._test(dtype, batch_dims + (rows, cols), full_matrices) diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..80c338513bc9ff6b8e56c5ad6b904af9e06a3715 --- /dev/null +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for quantized operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class QuantizedOpsTest(xla_test.XLATestCase): + + # Verify that quantized types can be clustered by XLA. 
+ def testQuantizedTypeRoundtrip(self): + with self.cached_session() as session: + for dtype in self.quantized_tf_types: + in_values = np.array([1, 2, 3, 4, 5, 6]) + expected = [[1, 2], [3, 4], [5, 6]] + with self.test_scope(): + p = array_ops.placeholder(dtype=dtypes.int32) + x = math_ops.cast(p, dtype) + x = array_ops.reshape(x, [3, 2]) + + value = session.run(x, {p: in_values}) + self.assertAllEqual(value, expected) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 6e183441179ebf2e8c063b333f9328d6fa86cc88..36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): - return set(self.numeric_types) - set(self.complex_types) + return set(self.numeric_types) - set( + self.complex_types) - {np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. @@ -68,9 +69,8 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): for dtype in self._random_types(): @@ -92,13 +92,13 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. 
- self._testRngIsNotConstant(rng, dtypes.float32) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testTruncatedNormalIsInRange(self): count = 10000000 - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + # TODO(b/34339814): make this test work with 16 bit float types. + for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) @@ -144,9 +144,6 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): - # TODO(b/26783907): this test requires the CPU backend to implement sort. - if self.device in ["XLA_CPU"]: - return with self.cached_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index c0ea242044540b1cef44186880ba3cd92b8849d6..dc119fb0f8a41a3772a8c9508bf2db657f57de88 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/core/common_runtime/device.h" @@ -61,8 +64,6 @@ limitations under the License. 
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { - return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); + return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } constexpr std::array kAllXlaTypes = { @@ -107,11 +108,12 @@ class OpTestBuilder { // Sets an attribute. template - OpTestBuilder& Attr(StringPiece attr_name, T&& value); + OpTestBuilder& Attr(absl::string_view attr_name, T&& value); // Overload needed to allow {...} expressions for value. template - OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list value); + OpTestBuilder& Attr(absl::string_view attr_name, + std::initializer_list value); // Adds nodes that executes the operator under test on 'device' to 'graphdef'. // If 'use_jit' is true, marks the operator under test to be compiled by XLA. 
@@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); return *this; } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, std::initializer_list value) { Attr>(attr_name, std::move(value)); return *this; @@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, NodeDef* test_def = graphdef->add_node(); *test_def = node_def_; - test_def->set_name(strings::StrCat(name_prefix, "_op_under_test")); + test_def->set_name(absl::StrCat(name_prefix, "_op_under_test")); test_def->set_device(device); AddDefaultsToNodeDef(*op_def, test_def); if (use_jit) { @@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_input_", i); + string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_output_", i); + string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -275,13 +277,13 @@ class OpTest : public ::testing::Test { // Select a random element from 'candidates'. 
template - T Choose(gtl::ArraySlice candidates); + T Choose(absl::Span candidates); static constexpr int kDefaultMaxRank = 5; static constexpr int64 kDefaultMaxDimensionSize = 256LL; // Returns true if 'dims' have a size less than tf_xla_max_tensor_size. - bool TensorSizeIsOk(gtl::ArraySlice dims); + bool TensorSizeIsOk(absl::Span dims); // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); @@ -307,11 +309,11 @@ class OpTest : public ::testing::Test { // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. Tensor RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape); + absl::Span shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. - Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomNonNegativeTensor(DataType dtype, absl::Span shape); Tensor RandomNonNegativeTensor(DataType dtype); // Returns a random subset of the integers in the range [0, rank), suitable @@ -415,7 +417,7 @@ void OpTest::Repeatedly(const std::function& fn) { } template -T OpTest::Choose(gtl::ArraySlice candidates) { +T OpTest::Choose(absl::Span candidates) { std::uniform_int_distribution d(0, candidates.size() - 1); return candidates[d(generator())]; } @@ -425,7 +427,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) { return size_distribution(generator()); } -bool OpTest::TensorSizeIsOk(gtl::ArraySlice dims) { +bool OpTest::TensorSizeIsOk(absl::Span dims) { int64 size = 1LL; for (int64 dim : dims) { size *= dim; @@ -451,11 +453,11 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, } Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { - gtl::FlatSet already_generated; + 
absl::flat_hash_set already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); test::FillFn(&tensor, [&](int i) -> float { float generated; @@ -468,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_DOUBLE: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); test::FillFn(&tensor, [&](int i) -> double { double generated; @@ -481,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_COMPLEX64: { - gtl::FlatSet> already_generated; + absl::flat_hash_set> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); test::FillFn(&tensor, [&](int i) { complex64 generated; @@ -498,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT32: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); test::FillFn(&tensor, [&](int i) -> int32 { int32 generated; @@ -511,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT64: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); test::FillFn(&tensor, [&](int i) -> int64 { @@ -525,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_BOOL: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::bernoulli_distribution distribution; test::FillFn(&tensor, [&](int i) -> bool { bool generated; @@ -548,7 +550,7 @@ Tensor OpTest::RandomTensor(DataType dtype) { } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -726,11 +728,11 @@ bool IsClose(const complex64& x, 
const complex64& y, double atol, template string Str(T x) { - return strings::StrCat(x); + return absl::StrCat(x); } template <> string Str(complex64 x) { - return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); + return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } template @@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { - return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", - Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), - "atol = ", atol, " rtol = ", rtol, - " tol = ", atol + rtol * Abs(Tx(i)))); + return errors::InvalidArgument( + absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)), + " vs. ", Str(Ty(i)), ". x = ", x.DebugString(), + "y = ", y.DebugString(), "atol = ", atol, + " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), ". 
x = ", x.DebugString(), "y = ", y.DebugString())); } @@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, double rtol) { if (a.dtype() != b.dtype()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Tensors have different types: ", DataTypeString(a.dtype()), " and ", DataTypeString(b.dtype()))); } if (!a.IsSameSize(b)) { - return errors::InvalidArgument(strings::StrCat( - "Tensors have different shapes: ", a.shape().DebugString(), " and ", - b.shape().DebugString())); + return errors::InvalidArgument( + absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(), + " and ", b.shape().DebugString())); } switch (a.dtype()) { @@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } string cpu_device = - LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); + LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - strings::StrCat("test", num_tests_, "_expected"), cpu_device, + absl::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, &expected_inputs, &expected_fetches); if (!status.ok()) { @@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } NodeDef* node_def; - status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"), test_device, tf_xla_test_use_jit, &graph, &node_def, &test_inputs, &test_fetches); if (!status.ok()) { @@ -1818,7 +1820,7 @@ TEST_F(OpTest, Diag) { do { dims = 
RandomDims(1); size = TensorShape(dims).num_elements(); - } while (size * size < tf_xla_max_tensor_size); + } while (size * size > tf_xla_max_tensor_size); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type)); }); @@ -1884,7 +1886,8 @@ TEST_F(OpTest, DynamicStitch) { for (int i = 0; i < n; ++i) { TensorShape shape(index_dims[i]); Tensor t = test::AsTensor( - gtl::ArraySlice(indices, pos, shape.num_elements()), shape); + absl::Span(indices).subspan(pos, shape.num_elements()), + shape); builder.Input(t); pos += t.NumElements(); } diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 5ae5b1bc1df76e6d0267a9a9ac18e7bc4725ec7b..132c59c32c9db0c8759bdbb31f8613c3ef88b485 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -219,7 +219,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): bf16_max = np.float32(dtypes.bfloat16.max) f32_max = dtypes.float32.max - value = min(bf16_max, f32_max - bf16_max) + value = min(bf16_max, f32_max - bf16_max) / 2 self._testReduceSum( dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py index 84c67779400f7a800bd88abc32d95058a6c0904d..96e0b074754032dd64c479b5e587b664ff066e2b 100644 --- a/tensorflow/compiler/tests/reshape_op_test.py +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): ('64_bit_index', dtypes.int64)) def testBasic(self, index_dtype): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): shape = constant_op.constant([3, 2], dtype=index_dtype) diff --git 
a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 60c2337743b44e9bad61c4d65280eb2b1a1ad9ea..abc822ef363e5d83c99bb963582662ccfce4cd6d 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): def testSeqLength(self): for dtype in self.all_types: - for seq_dtype in self.int_types: + for seq_dtype in self.all_types & {np.int32, np.int64}: self._testBasic(dtype, seq_dtype) diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 8f10c2fe864f6331299e60ddd25a486dfa478c37..2c611a959e1d71c53e44bc92c31258153d01507d 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -40,6 +40,19 @@ class SliceTest(xla_test.XLATestCase): self.assertAllEqual([2, 3, 4, 5], result) + def testZeroSlice(self): + for dtype in self.numeric_types: + with self.cached_session(): + i = array_ops.placeholder(dtype, shape=[2]) + with self.test_scope(): + o = array_ops.slice(i, [0], [0]) + params = { + i: [0, 1], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([], result) + def test3D(self): for dtype in self.numeric_types: with self.cached_session(): diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 51c04b5c4796474700a92a8b23a1cbdf533fcbb4..3e499c2fb176a6d63fe3590e18a4a90e461e096a 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,22 +48,33 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. 
- if self.device in ["XLA_CPU", "XLA_GPU"]: - return - - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) - def testTopK(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return + def testKeyValueSort(self): + supported_key_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + supported_value_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32, + dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype]) + for key_type in supported_key_types.intersection(self.numeric_types): + for value_type in supported_value_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -89,10 +100,6 @@ class XlaSortOpTest(xla_test.XLATestCase): expected=[x[indices].astype(dtype), indices]) def testTopK2D(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. 
- if self.device in ["XLA_CPU", "XLA_GPU"]: - return - supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -122,10 +129,6 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: @@ -144,10 +147,6 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 1bea7d9355e40c5a71f848dabc0fa7fa760429d2..e8741bc468585ff9fb049dcd87700f8048d74026 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): - return [dtypes.float32] + return self.float_types & {dtypes.float32, dtypes.float64} def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) @@ -91,7 +91,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - x = stateless.stateless_random_uniform( + x = stateless.stateless_random_normal( shape=[10000], 
seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) @@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._anderson_darling(y) < 2.492) def testTruncatedNormalIsInRange(self): - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + for dtype in self._random_types(): with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 @@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=5e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 78244d0b366d9128a4c59f786e4c5ac12e743b75..46ca371c8abf1cb4710717a183ee12820c4c4ca0 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -920,6 +920,34 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() + def _testTensorArrayScatterRead(self, tf_dtype): + with self.cached_session() as session, self.test_scope(): + convert = _make_converter(tf_dtype) + + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, + tensor_array_name="foo", + size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + 
# Test aggregation of read + read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8}) + self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) + self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + + def testTensorArrayScatterRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayScatterRead(dtype) + self._testTensorArrayScatterRead(dtypes.bool) + def testTensorArrayScatterReadAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -929,15 +957,18 @@ class TensorArrayTest(xla_test.XLATestCase): indices = constant_op.constant([1, 8]) value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) w = ta.scatter(indices, value) - r0 = w.read(1) - r1 = w.read(8) + r0 = w.read(id0) + r1 = w.read(id1) # Test combined gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) - read_vals, grad_vals = session.run([[r0, r1], grad]) + read_vals, grad_vals = session.run([[r0, r1], grad], + feed_dict={id0: 1, id1: 8}) self.assertEqual(len(read_vals), 2) self.assertEqual(len(grad_vals), 1) diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5c079d595c440cac644f5461154509abe7b1d1ed --- /dev/null +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ops which manipulate lists of tensors via bridge.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +def scalar_shape(): + return ops.convert_to_tensor([], dtype=dtypes.int32) + + +class ListOpsTest(xla_test.XLATestCase): + + def testElementShape(self): + with self.cached_session() as sess, self.test_scope(): + dim = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(dim, 15), num_elements=20, + element_dtype=dtypes.float32) + e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) + e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) + self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) + self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15)) + + def testPushPop(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = 
list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + + def testPushPopSeparateLists(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=scalar_shape(), + num_elements=num, + element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) + _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) + + def testEmptyTensorList(self): + dim = 7 + with self.cached_session() as sess, self.test_scope(): + p = array_ops.placeholder(dtypes.int32) + l = list_ops.empty_tensor_list( + element_shape=(p, 15), element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(dim, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Use TensorListReserve instead"): + self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + + +if __name__ == 
"__main__": + test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 55a992195f2df72677b77757ae86171fa662439f..98a07709c611178effd7794ba58ba89770c6d77f 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase): expected=np.array([[2], [5]], dtype=dtype)) def testClipByValue(self): - # TODO(b/78258593): enable integer types here too. - for dtype in self.float_types: + for dtype in self.numeric_types - self.complex_types: test_cases = [ (np.array([2, 4, 5], dtype=dtype), dtype(7)), # (dtype(1), np.array([2, 4, 5], dtype=dtype)), # diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py index 6abde18ea91f16d153a154b94effab037a911c6c..0e77dbf1a79d3dbacb77bab8b8e3df9bcc6287e1 100644 --- a/tensorflow/compiler/tests/test_utils.py +++ b/tensorflow/compiler/tests/test_utils.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): @@ -61,3 +62,14 @@ def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): dim_map = {d: i for i, d in enumerate(data_format_src)} permuted_dims = [dims[dim_map[d]] for d in data_format_dst] return permuted_dims + + +_JIT_WARMUP_ITERATIONS = 10 + + +def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None): + """Runs a graph a few times to ensure that its clusters are compiled.""" + for _ in xrange(0, _JIT_WARMUP_ITERATIONS): + sess.run(op_to_run, feed_dict, options=options) + return sess.run( + op_to_run, feed_dict, options=options, run_metadata=run_metadata) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 
5b0e57f83ff4b5a8d1891bef0675074bd67addce..77f6eee0cf8ddc9b76f150e1038bf66da34c5218 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllClose(result[i], expected[i], rtol, atol) def testAllTypeOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), np.array( @@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase): def testFloatOps(self): for dtype in self.float_types: - # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018. - if dtype == np.float16 and self.device == "XLA_CPU": - continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) @@ -633,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) def testNumericOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index b2f026df6c0c28fcbceaa0493871bc12c2d23b1f..4cf88fc523735cc2d22e085afb83790c7ebb48e4 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -97,9 +98,9 @@ class XlaOpsTest(xla_test.XLATestCase, 
parameterized.TestCase): args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) - PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, - xla_data_pb2.PrecisionConfigProto.HIGH, - xla_data_pb2.PrecisionConfigProto.HIGHEST) + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT, + xla_data_pb2.PrecisionConfig.HIGH, + xla_data_pb2.PrecisionConfig.HIGHEST) @parameterized.parameters(*PRECISION_VALUES) def testConv(self, precision): @@ -120,7 +121,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.conv( lhs, @@ -151,7 +152,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.rhs_batch_dimensions.append(0) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.dot_general( lhs, @@ -180,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dtype=dtype)) def testNeg(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.uint8, np.int8}: self._assertOpOutputMatchesExpected( xla.neg, args=(np.array([1, 2, 3], dtype=dtype),), @@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + def testDynamicSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_slice, + args=(np.arange(1000, + dtype=np.int32).astype(dtype).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3, 
2])), + expected=np.array( + np.array([[[573, 574], [583, 584], [593, 594]], + [[673, 674], [683, 684], [693, 694]]]), + dtype=dtype)) + + def testDynamicSliceWithIncorrectStartIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7]), np.array([2, 3, 4])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^start_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and start_indices has shape \[2\].*')) + + def testDynamicSliceWithIncorrectSizeIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^size_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and size_indices has shape \[2\].*')) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 88827cb53bee7bb809d0163d6badcef17e59aa78..98a41981cf30917bc2054c19af5d8176bdfc9862 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -97,10 +97,23 @@ class XLATestCase(test.TestCase): ]) self._numeric_tf_types = set( self.int_tf_types | self._float_tf_types | self.complex_tf_types) - - self._all_types = set( - [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.quantized_tf_types = set( + dtype for dtype in self._all_tf_types if dtype.is_quantized) + + # Quantized types don't have a numpy equivalent, include them in + # all_tf_types but not in 
all_types. + # TODO(b/115960798): Parametrize tests on TF types instead of numpy types + # and remove all_types. + self._all_types = set(dtype.as_numpy_dtype + for dtype in self._all_tf_types + if not dtype.is_quantized) self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self.signed_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if not dtype.is_unsigned) + self.unsigned_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if dtype.is_unsigned) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 92e577bb7b930f5b9139e361cafb8628daede455..f0e7791e9811533502fae0d4dea5a2e1ca2cf33c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,6 +76,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -188,9 +189,9 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", + ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -214,6 +215,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -239,6 +241,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", ], ) @@ -281,6 +284,7 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -289,6 +293,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", 
"//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -303,6 +308,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:protos_all_cc", @@ -356,6 +362,7 @@ tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ + ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", @@ -367,6 +374,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", @@ -431,6 +439,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -472,6 +481,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -499,11 +509,23 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) +cc_library( + name = "functionalize_control_flow_pass_registration", + srcs = [ + "functionalize_control_flow_pass_registration.cc", + ], + deps = [ + ":functionalize_control_flow", + ], + alwayslink = 1, +) + cc_library( name = "functionalize_while", srcs = [ @@ -513,6 +535,7 @@ cc_library( "functionalize_while.h", ], deps = [ + ":functionalize_cond", 
":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -523,6 +546,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -537,6 +561,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", @@ -587,6 +612,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -607,11 +633,11 @@ cc_library( srcs = ["resource_operation_table.cc"], hdrs = ["resource_operation_table.h"], deps = [ - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", ], ) @@ -625,6 +651,17 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], + deps = [ + "//tensorflow/core:core_cpu", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ea8d1b3d14939d4f4fba598318200f71c2eb0270..adcdb6c8f762cb7ea68485167bf7fc8ccb343a51 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -30,14 +30,15 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", - 
out_ops_file = "ops/xla_jit_op", + include_internal_ops = 1, + out_ops_file = "ops/xla_jit_ops", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) cc_library( name = "xla_jit_ops", - srcs = ["ops/xla_jit_op.cc"], - hdrs = ["ops/xla_jit_op.h"], + srcs = ["ops/xla_jit_ops.cc"], + hdrs = ["ops/xla_jit_ops.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index e8673d77903bd5a1a85412e9dfa86437f73d56bc..027ca6d2d2f616177d91d9d57d1ff373bab2a754 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -26,16 +26,9 @@ namespace tensorflow { // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_args, - std::vector* compile_time_const_nodes) { - // Operators that don't look at the data of their inputs, just the shapes. - const std::unordered_set metadata_ops = { - "Rank", - "Shape", - "ShapeN", - "Size", - }; - + std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + std::function edge_filter) { std::vector compile_time_const_nodes_impl; if (compile_time_const_nodes) { CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); @@ -45,12 +38,13 @@ Status BackwardsConstAnalysis(const Graph& g, } Status status; - auto visit = [&status, &metadata_ops, compile_time_const_nodes, - compile_time_const_args](Node* node) { + auto visit = [&](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. - if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return; + if (XlaOpRegistry::IsMetadataOp(node->type_string())) { + return; + } // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. 
@@ -59,13 +53,13 @@ Status BackwardsConstAnalysis(const Graph& g, int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - if (compile_time_const_args) { - (*compile_time_const_args)[index] = true; + if (compile_time_const_arg_indices) { + (*compile_time_const_arg_indices)[index] = true; } return; } for (const Edge* pred : node->in_edges()) { - if (!pred->IsControlEdge()) { + if (!pred->IsControlEdge() && edge_filter(*pred)) { (*compile_time_const_nodes)[pred->src()->id()] = true; } } @@ -88,7 +82,8 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && - edge->dst_input() < name_range->second.second) { + edge->dst_input() < name_range->second.second && + edge_filter(*edge)) { (*compile_time_const_nodes)[edge->src()->id()] = true; } } @@ -97,7 +92,8 @@ Status BackwardsConstAnalysis(const Graph& g, // Post-order traversal visits nodes in reverse topological order for an // acyclic graph. - DFS(g, {}, visit); + DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{}, + [](const Edge& edge) { return !edge.src()->IsNextIteration(); }); return status; } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index af57e5a4033248e3fd32dabeda252c4ca0a44050..49b3c6d413c6b637fa825bf182be7cc36e49b6c8 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -32,9 +32,13 @@ namespace tensorflow { // // The ids of the nodes in `graph` that must be constant are returned in // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. -Status BackwardsConstAnalysis(const Graph& graph, +// +// Only propagate const-ness along edges for which `edge_filter` returns true. 
+Status BackwardsConstAnalysis(const Graph& g, std::vector* compile_time_const_arg_indices, - std::vector* compile_time_const_nodes); + std::vector* compile_time_const_nodes, + std::function edge_filter = + [](const Edge& e) { return true; }); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 24616c01c7e54b2e8662457ca6af23a0bc563e08..380c6a7e23da92d949b26876836b999bf6406c6c 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -52,9 +52,9 @@ string MakeUniqueFilename(string name) { string filename = name; if (count > 0) { - strings::StrAppend(&filename, "_", count); + absl::StrAppend(&filename, "_", count); } - strings::StrAppend(&filename, ".pbtxt"); + absl::StrAppend(&filename, ".pbtxt"); return filename; } @@ -69,7 +69,7 @@ string WriteTextProtoToUniqueFile( << proto_type << ": " << status; return "(unavailable)"; } - string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); status = WriteTextProto(Env::Default(), filepath, proto); if (!status.ok()) { LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index b5667ca0d3ba35bea9da2d702b5b49fb38fe6f02..46649b8cc43016d4a62f49e20256c77ca8accc79 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,30 +34,16 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { -string DebugString(const CondStateMap::CondNode& node) { - return node.ToString(); -} - // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { - return strings::StrCat(tensor.node->name(), ":", tensor.index); -} - -string DebugString(CondStateMap::CondId cond_state) { - if (cond_state == nullptr || cond_state->empty()) return "[]"; - return strings::StrCat( - "[", - absl::StrJoin(*cond_state, ", ", - [](string* output, const CondStateMap::CondNode& node) { - strings::StrAppend(output, node.ToString()); - }), - "]"); + return absl::StrCat(tensor.node->name(), ":", tensor.index); } string Branch_Name(BranchType b) { @@ -73,6 +59,24 @@ string Branch_Name(BranchType b) { } } +string DebugString(StateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "{}"; + using value_type = StateMap::CondState::value_type; + return absl::StrCat( + "{", + absl::StrJoin(*cond_state, ", ", + [](string* output, const value_type& pred_branch) { + const OutputTensor& pred = pred_branch.first; + const BranchType& branch = pred_branch.second; + if (branch == BranchType::kNeither) + absl::StrAppend(output, "d"); + else + absl::StrAppend(output, "s(", DebugString(pred), ",", + Branch_Name(branch), ")"); + }), + "}"); +} + // Returns the predicate of a switch. 
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { const Edge* pred_edge; @@ -86,64 +90,65 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { return Status::OK(); } -CondStateMap::CondNode::CondNode(Type type, Node* switch_node, - BranchType branch) - : type(type), branch(branch) { - if (type == Type::kSwitch) { - TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate)); - } -} - -string CondStateMap::CondNode::ToString() const { - switch (type) { - case Type::kSwitch: - return strings::StrCat("s(", DebugString(predicate), ",", - Branch_Name(branch), ")"); - case Type::kMerge: - return "m"; - case Type::kDead: - return "d"; - } +Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { + const Edge* val_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); + *val = OutputTensor(val_edge->src(), val_edge->src_output()); + return Status::OK(); } -bool CondStateMap::CondNode::operator==(const CondNode& other) const { - if (type != Type::kSwitch) return type == other.type; - return type == other.type && predicate == other.predicate && - branch == other.branch; +bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, + const OutputTensor& rhs) const { + return (lhs.node->id() < rhs.node->id()) || + (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index); } -bool CondStateMap::CondNode::operator!=(const CondNode& other) const { - return !(*this == other); -} +struct CondStateLess { + bool operator()(const StateMap::CondState::value_type& lhs, + const StateMap::CondState::value_type& rhs) const { + if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first)) + return true; + if (lhs.first.node->id() == rhs.first.node->id() && + lhs.first.index == rhs.first.index) + return lhs.second < rhs.second; + return false; + } +}; -CondStateMap::CondStateMap(Graph* graph) { +StateMap::StateMap(Graph* graph) { node_to_condid_map_.resize(graph->num_node_ids()); + 
node_to_ancestorid_map_.resize(graph->num_node_ids()); // Initialize the dead state (empty state is designated with a nullptr). - dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)}); + dead_id_ = GetCondId( + {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)}); } -bool CondStateMap::IsDead(CondStateMap::CondId id) const { - return id == dead_id_; -} +bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; } -bool CondStateMap::IsEmpty(CondStateMap::CondId id) const { - return id == nullptr; -} +bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondNode& item) const { - return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate), - hash()(item.branch)), - hash()(item.type)); +size_t StateMap::Hash::operator()(const StateMap::CondState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second)); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second))); + } + return h; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondState& vec) const { - if (vec.empty()) return 0; - size_t h = (*this)(vec.front()); - auto it = vec.begin(); - for (++it; it != vec.end(); ++it) { - h = Hash64Combine(h, (*this)(*it)); +size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = hash()(*it); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. 
+ h = Hash64Combine(h, hash()(*it)); } return h; } @@ -155,8 +160,8 @@ struct CondArgNode { : src(src), src_output(src_output) {} string ToString() const { - return strings::StrCat("src=", src->name(), ":", src_output, - " switches=", NodesToString(switches)); + return absl::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); } Node* src; @@ -167,58 +172,76 @@ struct CondArgNode { using CondArgNodes = std::vector; string DebugString(const CondArgNodes& nodes) { - return strings::StrCat( + return absl::StrCat( "[", absl::StrJoin(nodes, ", ", [](string* output, const CondArgNode& node) { - strings::StrAppend(output, node.ToString()); + absl::StrAppend(output, node.ToString()); }), "]"); } -CondStateMap::CondId CondStateMap::LookupId(const Node* node) const { +StateMap::CondId StateMap::LookupCondId(const Node* node) const { if (node->id() < node_to_condid_map_.size()) return node_to_condid_map_[node->id()]; - return added_node_mapping_.at(node->id()); + return added_node_condid_mapping_.at(node->id()); } -CondStateMap::CondId CondStateMap::GetUniqueId( - const CondStateMap::CondState& state) { +StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) { if (state.empty()) return nullptr; return &*condstate_set_.insert(state).first; } -const CondStateMap::CondState& CondStateMap::LookupState( - const Node* node) const { - return *LookupId(node); -} - -void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) { +void StateMap::ResetCondId(const Node* node, StateMap::CondId id) { if (node->id() < node_to_condid_map_.size()) node_to_condid_map_[node->id()] = id; else - added_node_mapping_[node->id()] = id; + added_node_condid_mapping_[node->id()] = id; } -void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); } +StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const { + if (node->id() < node_to_ancestorid_map_.size()) + return node_to_ancestorid_map_[node->id()]; + return 
added_node_ancestorid_mapping_.at(node->id()); +} -string CondStateMap::CondStateToString(const Node* node) const { - return CondStateToString(LookupId(node)); +StateMap::AncestorId StateMap::GetAncestorId( + const StateMap::AncestorState& state) { + if (state.empty()) return nullptr; + return &*ancestorstate_set_.insert(state).first; } -string CondStateMap::CondStateToString(CondStateMap::CondId id) const { +void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { + if (node->id() < node_to_ancestorid_map_.size()) + node_to_ancestorid_map_[node->id()] = id; + else + added_node_ancestorid_mapping_[node->id()] = id; +} + +void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } + +string StateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupCondId(node)); +} + +string StateMap::CondStateToString(StateMap::CondId id) const { return DebugString(id); } +string StateMap::AncestorStateToString(const Node* node) const { + if (auto id = LookupAncestorId(node)) return NodesToString(*id); + return "{}"; +} + FunctionalizeCond::FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : cond_state_map_(graph), library_(library), graph_(graph) {} + : state_map_(graph), library_(library), graph_(graph) {} // Class representing the merge/switch nodes that will become a conditional. class Conditional { public: Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map); + StateMap* cond_state_map); // Adds merge node that is part of this conditional. Status AddMerge(Node* m); @@ -247,6 +270,10 @@ class Conditional { // Adds switch node that is part of this conditional. Status AddSwitch(Node* s); + // Adds a switch node along the edge and rewire the edge to go via the switch. + Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph); + // Internal name of conditional. The name is based on the first merge node // added. 
string name() const; @@ -255,7 +282,7 @@ class Conditional { FunctionalizeCond* parent_; // Mapping between nodes and their cond state. - CondStateMap* cond_state_map_; + StateMap* state_map_; // The predicate of the conditional. OutputTensor predicate_; @@ -292,8 +319,8 @@ class Conditional { }; Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map) - : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {} + StateMap* cond_state_map) + : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {} Status Conditional::AddMerge(Node* m) { merges_.insert(m); @@ -343,7 +370,7 @@ Status Conditional::BuildArgumentNodes() { for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_Arg", arg_count), + NodeBuilder(absl::StrCat("_Arg", arg_count), FunctionLibraryDefinition::kArgOp) .Attr("T", dtype) .Attr("index", arg_count) @@ -397,6 +424,35 @@ Status Conditional::BuildArgumentNodes() { return Status::OK(); } +Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph) { + // Previously we had edge: + // src:src_output ---- edge ----> dst:dst_input + // post this we have (in graph) + // src:src_output --> switch --- new_edge --> dst:dst_input + + // TODO(jpienaar): One could keep a map caching the extra switch nodes added + // to avoid adding another switch to feed a value for which a switch was + // already added. 
+ Node* switch_node; + Node* src = edge->src(); + int src_output = edge->src_output(); + TF_RETURN_IF_ERROR( + NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")), + "Switch") + .Input(src, src_output) + .Input(const_cast(predicate_.node), predicate_.index) + .Finalize(graph, &switch_node)); + state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src)); + state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src)); + + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + graph->AddEdge(switch_node, static_cast(branch), dst, dst_input); + return AddSwitch(switch_node); +} + Status Conditional::ExtractBodies(Graph* graph) { VLOG(2) << "Extracting bodies for " << name(); for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { @@ -405,16 +461,16 @@ Status Conditional::ExtractBodies(Graph* graph) { } auto find_branch = [&](const Edge* e) { - const auto& id = cond_state_map_->LookupId(e->src()); + const auto& id = state_map_->LookupCondId(e->src()); return IsSwitch(e->src()) ? BranchType(e->src_output()) - : cond_state_map_->FindBranchOf(id, predicate_); + : state_map_->FindBranchOf(id, predicate_); }; std::array, 2> stacks; VLOG(5) << "Merges: " << NodesToString(merges_); for (Node* m : merges_) { VLOG(5) << "For merge: " << m->DebugString() << " " - << cond_state_map_->CondStateToString(m); + << state_map_->CondStateToString(m); for (auto e : m->in_edges()) { if (e->IsControlEdge()) continue; BranchType branch = find_branch(e); @@ -422,7 +478,8 @@ Status Conditional::ExtractBodies(Graph* graph) { branch == BranchType::kElseBranch) << "Error: " << e->src()->name() << " is not on either then or else branch (" << Branch_Name(branch) - << ")."; + << ") for predicate " << DebugString(predicate_) << " [" + << DebugString(state_map_->LookupCondId(e->src())) << "]."; Node* src = e->src(); if (IsSwitch(src)) { // Switch node outputs and dependencies are handled separately. 
@@ -456,8 +513,8 @@ Status Conditional::ExtractBodies(Graph* graph) { if (IsMerge(dst)) continue; Node* src = e->src(); - auto dst_id = cond_state_map_->LookupId(dst); - auto src_id = cond_state_map_->LookupId(src); + auto dst_id = state_map_->LookupCondId(dst); + auto src_id = state_map_->LookupCondId(src); if (dst_id != src_id) { if (e->IsControlEdge()) { external_control_outputs_.push_back(e->src()); @@ -480,8 +537,11 @@ Status Conditional::ExtractBodies(Graph* graph) { } } - // Copying incomming edges to dst node. - for (const Edge* e : n->in_edges()) { + // Copying incomming edges to dst node. Iterate over a copy of the edges + // as they could be mutated during iteration. + std::vector in_edges(n->in_edges().begin(), + n->in_edges().end()); + for (const Edge* e : in_edges) { Node* src = e->src(); // Skip src/dst node. if (!src->IsOp()) continue; @@ -494,8 +554,8 @@ Status Conditional::ExtractBodies(Graph* graph) { } // Verify input is from the same context. - auto src_id = cond_state_map_->LookupId(src); - auto dst_id = cond_state_map_->LookupId(dst); + auto src_id = state_map_->LookupCondId(src); + auto dst_id = state_map_->LookupCondId(dst); if (IsMerge(dst) || src_id == dst_id) { // TODO(jpienaar): The merge case can be more strict. if (node_map.at(src->id()) == nullptr) { @@ -506,18 +566,25 @@ Status Conditional::ExtractBodies(Graph* graph) { external_control_inputs_.push_back(src); } else { // This shouldn't happen, this means we have an external data input - // not entering via a switch node. Work around this for constant - // nodes as some constant nodes are inserted without the required - // control context dominance. + // not entering via a switch node. 
Work around this by for + // * constant nodes copy them; + // * non-constant nodes, insert a switch along the edge; if (IsConstant(src)) { node_map.at(src->id()) = output->CopyNode(src); } else { - return errors::InvalidArgument( - "Graph contains node ", FormatNodeForError(*src), - " that feeds into node ", FormatNodeForError(*dst), - " but these nodes are in different control contexts (", - DebugString(src_id), " vs ", DebugString(dst_id), - " (detected during in edge testing)"); + StateMap::CondState state = *dst_id; + state.erase(predicate_); + if (state_map_->GetCondId(state) == src_id) { + TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph)); + continue; + } else { + return errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } } } @@ -572,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If"); + NodeDefBuilder builder(name(), "If", library); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -580,8 +647,8 @@ Status Conditional::BuildIfNode(Graph* graph, int64 id = ++sequence_num; NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_if_", - branch_name[branch_index], "_", id)); + body_name.set_name( + absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id)); VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] << "): " @@ -628,6 +695,12 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build output type: " << DataTypeVectorString(out_type); 
builder.Attr("Tcond", DT_BOOL); + string outside_compilation; + if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), @@ -639,7 +712,8 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build If node"; NodeDef if_def; TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin())); + TF_ASSIGN_OR_RETURN(if_node_, + parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); return Status::OK(); } @@ -699,7 +773,8 @@ Status Conditional::AddOutputEdges(Graph* graph) { Status Conditional::BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(1) << "Build If and replace merge nodes " << name(); + VLOG(1) << "Build If and replace merge nodes " + << NodesToString(this->merges_); if (replaced_) return Status::OK(); TF_RETURN_IF_ERROR(ExtractBodies(graph)); @@ -719,7 +794,6 @@ Status Conditional::BuildAndReplace(Graph* graph, TF_RETURN_IF_ERROR(AddInputEdges(graph)); TF_RETURN_IF_ERROR(AddOutputEdges(graph)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); - for (Node* m : merges_) cond_state_map_->MarkDead(m); // Check that the if_node doesn't feed into itself. 
TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -732,31 +806,7 @@ Status Conditional::BuildAndReplace(Graph* graph, string Conditional::name() const { CHECK(!merges_.empty()); - return strings::StrCat((*merges_.begin())->name(), "_if"); -} - -bool CondStateMap::ScopeIn(CondStateMap::CondId id, - CondStateMap::CondId* scope) { - if (id == nullptr) { - *scope = nullptr; - return true; - } - CondState state; - for (const CondNode& node : *id) { - if (node.type == CondNode::Type::kSwitch) { - state.push_back(node); - } - if (node.type == CondNode::Type::kMerge) { - if (state.empty()) { - return false; - } - DCHECK(state.back().type == CondNode::Type::kSwitch && - state.back().branch == BranchType::kBoth); - state.pop_back(); - } - } - *scope = GetUniqueId(state); - return true; + return absl::StrCat((*merges_.begin())->name(), "_if"); } Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, @@ -765,25 +815,35 @@ Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") .Input(if_node, port) .Finalize(graph_, &id)); - cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node)); + state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); + state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); return Status::OK(); } StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, - const Node* replacee) { + const Node* replacee, + const OutputTensor& predicate) { Status status; Node* ret = graph_->AddNode(def, &status); TF_RETURN_IF_ERROR(status); - CondStateMap::CondState state = cond_state_map_.LookupState(replacee); - state.pop_back(); VLOG(1) << "Adding If for " << replacee->name(); - cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state)); + StateMap::CondId id = state_map_.LookupCondId(replacee); + if (id) { + StateMap::CondState state = *id; + state.erase(predicate); + state_map_.ResetCondId(ret, state_map_.GetCondId(state)); + } else { + 
state_map_.ResetCondId(ret, nullptr); + } + + state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee)); + return ret; } Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { VLOG(2) << "Propagating update state for " << replacee->name() << " " - << cond_state_map_.CondStateToString(replacee); + << state_map_.CondStateToString(replacee); // Redo topological sort as the order could have changed. // TODO(jpienaar): The original topological order could also be updated // dynamically if needed. @@ -801,10 +861,10 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { if (changed.find(*it) != changed.end()) { // Update the node state. Node* n = *it; - CondStateMap::CondId old_state = cond_state_map_.LookupId(n); - cond_state_map_.ResetId(n, nullptr); + StateMap::CondId old_state = state_map_.LookupCondId(n); + state_map_.ResetCondId(n, nullptr); TF_RETURN_IF_ERROR(DetermineCondState(n)); - if (cond_state_map_.LookupId(n) != old_state) { + if (state_map_.LookupCondId(n) != old_state) { for (auto out : n->out_nodes()) if (out->IsOp()) changed.insert(out); } @@ -825,127 +885,44 @@ BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { return BranchType::kNeither; } -CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - CondId lhs_scope; - CondId rhs_scope; - bool could_determine_scope = ScopeIn(lhs, &lhs_scope); - could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope); - if (!could_determine_scope) return kIncomparable; - - // Returns whether a contains b. - auto contains = [&](CondId a, CondId b) { - // Handle empty states. 
- if (a == nullptr && b != nullptr) return true; - if (a == nullptr && b == nullptr) return true; - if (a != nullptr && b == nullptr) return false; - - if (a->size() > b->size()) return false; - auto a_it = a->begin(); - auto b_it = b->begin(); - while (a_it != a->end()) { - if (*a_it != *b_it) { - if (!(a_it->predicate == b_it->predicate)) return false; - BranchType mb = MeetBranch(a_it->branch, b_it->branch); - if (mb != b_it->branch) return false; - } - ++a_it; - ++b_it; - } - return true; - }; - - bool lhs_contains_rhs = contains(lhs_scope, rhs_scope); - bool rhs_contains_lhs = contains(rhs_scope, lhs_scope); - if (lhs_contains_rhs && rhs_contains_lhs) return kEqual; - if (lhs_contains_rhs) return kLhsContainsRhs; - if (rhs_contains_lhs) return kRhsContainsLhs; - return kIncomparable; -} - -BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const { +BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const { if (IsEmpty(id)) return BranchType::kNeither; - absl::optional b; const CondState& nodes = *id; - for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == predicate) { - if (b.has_value()) { - b = MeetBranch(*b, it->branch); - } else { - b = it->branch; - } - if (*b == BranchType::kNeither) { - LOG(FATAL) << "Inconsistent state for node: " << DebugString(id); - } - } - } - return b.has_value() ? 
*b : BranchType::kNeither; + auto it = nodes.find(predicate); + if (it == nodes.end()) return BranchType::kNeither; + return it->second; } -StatusOr FunctionalizeCond::JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - VLOG(4) << "Joining src=" << DebugString(src) << " [" << src +StatusOr FunctionalizeCond::JoinCondStatesNonMerge( + StateMap::CondId src, StateMap::CondId dst) { + VLOG(5) << "Joining src=" << DebugString(src) << " [" << src << "] and dst=" << DebugString(dst) << " [" << dst << "]"; - if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; + if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst; // Nothing to do if the CondState is the same. if (src == dst) return src; - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope); - switch (result) { - case CondStateMap::kIncomparable: - return errors::InvalidArgument( - "Graph contains node with inputs predicated on incompatible " - "predicates: ", - DebugString(src), " and ", DebugString(dst)); - case CondStateMap::kEqual: - // If both respect the same predicates, propagate the longer constraint. - if ((src != nullptr && dst == nullptr) || - (src != nullptr && dst != nullptr && src->size() > dst->size())) - return src; - else - return dst; - case CondStateMap::kLhsContainsRhs: - // src contains dst, so dst is already more restrictive. 
- return dst; - case CondStateMap::kRhsContainsLhs: - // dst contains src, so src is more restrictive. - return src; - } -} - -StatusOr -FindThenElseSwitchForPredicate(const OutputTensor& pred, - CondStateMap::CondId id) { - for (auto it = id->begin(); it != id->end(); ++it) { - // Along every path one there can be only one instance of a then or else - // switch for a given predicate, so return once found. - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == pred && - (it->branch == BranchType::kThenBranch || - it->branch == BranchType::kElseBranch)) - return it; + StateMap::CondState both = *src; + for (const auto& kv : *dst) { + auto it = both.find(kv.first); + if (it == both.end()) { + both.insert(kv); + } else { + if (it->second != kv.second) { + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + } + } } - return errors::Internal("Unable to find then/else branch with predicate ", - DebugString(pred), " for ", DebugString(id)); + return state_map_.GetCondId(both); } -StatusOr FunctionalizeCond::JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { +StatusOr FunctionalizeCond::JoinCondStatesMerge( + Node* merge, StateMap::CondId src, StateMap::CondId dst) { // Determine the flow state when joining two states for a merge // node. Combining the two states for a merge node is effectively performing a // disjunction of the states along the different input edges. For a merge that @@ -956,91 +933,56 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( // followed by s(p, both). 
VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); - if (cond_state_map_.IsEmpty(dst)) return src; - - if (cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; - - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr) - << "Illegal merge inputs from outer scope: src=" << DebugString(src) - << " dst=" << DebugString(dst); - auto src_it = src_scope->begin(); - auto dst_it = dst_scope->begin(); - - // Find branch divergent condition. - OutputTensor pred; - while (src_it != src_scope->end() && dst_it != dst_scope->end()) { - if (*src_it != *dst_it) { - VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and " - << DebugString(*dst_it); - if (!(src_it->predicate == dst_it->predicate)) { - return errors::InvalidArgument( - "Unable to find common predicate which holds for one input " - "but not the other of the merge node."); - } - pred = src_it->predicate; - break; - } - ++src_it; - ++dst_it; - } - - if (pred.node == nullptr) - return errors::InvalidArgument("Unable to determine predicate for merge."); - - TF_ASSIGN_OR_RETURN(auto div_src_it, - FindThenElseSwitchForPredicate(pred, src)); - TF_ASSIGN_OR_RETURN(auto div_dst_it, - FindThenElseSwitchForPredicate(pred, dst)); - TF_RET_CHECK(*div_src_it != *div_dst_it); - - CondStateMap::CondState result; - // Populate result with the longest/most restrictive path up to the divergent - // node. 
For example, if the one input is `[switch(pred:0, then)]` and the - // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created - // in gradient of cond test), then the resultant state here should be - // `[switch(pred:0, both), merge, switch(pred:0, both)]`. - if (std::distance(src->begin(), div_src_it) > - std::distance(dst->begin(), div_dst_it)) { - result.assign(src->begin(), std::next(div_src_it)); + if (state_map_.IsEmpty(dst)) return src; + + if (state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst)) return dst; + + std::vector diff; + StateMap::CondState merged; + std::set_symmetric_difference(src->begin(), src->end(), dst->begin(), + dst->end(), std::back_inserter(diff), + CondStateLess()); + std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(), + std::inserter(merged, merged.begin()), CondStateLess()); + + // Update mapping from merge node to predicate. + if (diff.size() == 2) { + auto pred = diff[0].first; + bool different_branches = (diff[0].second != diff[1].second) && + (diff[0].second == BranchType::kThenBranch || + diff[0].second == BranchType::kElseBranch) && + (diff[1].second == BranchType::kThenBranch || + diff[1].second == BranchType::kElseBranch); + if (!(pred == diff[1].first) || !different_branches) + return errors::InvalidArgument( + "Unable to determine predicate for merge node"); + merge_to_predicate_[merge] = pred; } else { - result.assign(dst->begin(), std::next(div_dst_it)); + return errors::InvalidArgument( + "Merge of two inputs that differ on more than one predicate ", + DebugString(src), " and ", DebugString(dst)); } - result.back().branch = BranchType::kBoth; - return cond_state_map_.GetUniqueId(result); + + return state_map_.GetCondId(merged); } -CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { +StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { Node* src = e->src(); - CondStateMap::CondId id = cond_state_map_.LookupId(e->src()); - if 
(IsMerge(src)) { - CondStateMap::CondState state; - if (id != nullptr) state = *id; - state.emplace_back(CondStateMap::CondNode::Type::kMerge); - return cond_state_map_.GetUniqueId(state); - } + StateMap::CondId id = state_map_.LookupCondId(e->src()); + + // Dead nodes only propagate dead state. + if (state_map_.IsDead(id)) return id; + if (IsSwitch(src)) { - CondStateMap::CondState state; + StateMap::CondState state; if (id != nullptr) state = *id; - if (e->IsControlEdge()) { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType::kBoth); - } else { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType(e->src_output())); + OutputTensor predicate; + TF_CHECK_OK(GetSwitchPredicate(*src, &predicate)); + if (!e->IsControlEdge()) { + state[predicate] = BranchType(e->src_output()); } - return cond_state_map_.GetUniqueId(state); + return state_map_.GetCondId(state); } return id; } @@ -1049,22 +991,21 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. 
- if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst))) - return Status::OK(); + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK(); int data_inputs = 0; for (auto e : dst->in_edges()) { Node* src = e->src(); VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(src); + << state_map_.CondStateToString(src); if (!src->IsOp()) continue; if (!e->IsControlEdge()) ++data_inputs; - CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } // Incomplete Merge nodes are not supported. @@ -1076,27 +1017,20 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { return Status::OK(); } -Status FunctionalizeCond::DetermineCondState(Node* dst) { - // The logic for the merge and non-merge case differ: for non-merge it is - // the most restrictive CondState, while for merge nodes the - // resultant state is less restrictive than either. - if (IsMerge(dst)) { - TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst)); - } else { - // Handle non-merge join. - for (auto e : dst->in_edges()) { - VLOG(5) << "Processing forward flow for: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(dst); - Node* src = e->src(); - if (!src->IsOp()) continue; - - // Joining the state between the current and propagated state. 
- CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", - FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); - } +Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(4) << "Processing forward flow for: " << e->DebugString() << " " + << state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } return Status::OK(); } @@ -1104,8 +1038,7 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) { Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. - if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node))) - return Status::OK(); + if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK(); const Edge* non_dead_edge = nullptr; for (auto e : node->in_edges()) { @@ -1113,8 +1046,8 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { Node* src = e->src(); // Handle merge with dead state. 
- const auto& src_id = cond_state_map_.LookupId(src); - if (!cond_state_map_.IsDead(src_id)) { + const auto& src_id = state_map_.LookupCondId(src); + if (!state_map_.IsDead(src_id)) { non_dead_edge = e; break; } @@ -1124,8 +1057,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), " has no non-dead inputs."); } - cond_state_map_.MarkDead(node); - delete_nodes_.push_back(node->id()); + state_map_.MarkDead(node); VLOG(5) << "removing redundant merge: " << node->name(); while (!node->out_edges().empty()) { const Edge* oe = *node->out_edges().begin(); @@ -1149,16 +1081,33 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // along one. The checking of predicate is based on the exact predicate // (rather than boolean equivalence) and aimed at redundant switches as // currently generated by gradient code. + StateMap::CondId dst_id = state_map_.LookupCondId(node); + if (state_map_.IsDead(dst_id)) return Status::OK(); + + BranchType b; OutputTensor pred; TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); - auto dst_id = cond_state_map_.LookupId(node); - BranchType b = cond_state_map_.FindBranchOf(dst_id, pred); + // Determine if we are already on a branch where the switch predicate is - // true/false. - if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) - return Status::OK(); + // true/false. Consider both the data and predicate to determine if the + // node is redundant (skipping over identity node). 
+ b = state_map_.FindBranchOf(dst_id, pred); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) { + OutputTensor val; + const Edge* e; + TF_RETURN_IF_ERROR(node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + while (IsIdentity(val.node)) { + TF_RETURN_IF_ERROR(val.node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + } + b = state_map_.FindBranchOf(dst_id, val); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + } - VLOG(5) << "Redundant switch " << node->name(); + VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " + << DebugString(dst_id); const Edge* value_edge; TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); Node* val_node = value_edge->src(); @@ -1171,20 +1120,19 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { graph_->RemoveEdge(e); if (switch_branch == Graph::kControlSlot) { if (IsMerge(dst_node)) { - auto id_or = - JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); + auto id_or = JoinCondStatesMerge(dst_node, dst_id, + state_map_.LookupCondId(dst_node)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst_node)); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } else { auto id_or = - JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node)); + JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node)); TF_RETURN_IF_ERROR(id_or.status()); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } } else if (BranchType(switch_branch) != b) { - cond_state_map_.MarkDead(dst_node); - delete_nodes_.push_back(dst_node->id()); + state_map_.MarkDead(dst_node); continue; } graph_->AddEdge( @@ -1195,37 +1143,103 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { return Status::OK(); } -Status 
FunctionalizeCond::DetermineCondStates( - std::vector rev_topo_order) { +Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { // The state that is propagated along the given edge. for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { Node* dst = *it; TF_RETURN_IF_ERROR(DetermineCondState(dst)); + TF_RETURN_IF_ERROR(DetermineAncestorState(dst)); if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); - VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst); + VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) + << " @ " << state_map_.AncestorStateToString(dst); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); } return Status::OK(); } -void FunctionalizeCond::DeleteReachableNodes() { +Status FunctionalizeCond::DetermineAncestorState(Node* dst) { + StateMap::AncestorId id = nullptr; + StateMap::AncestorState state; + + auto insert = [&](StateMap::AncestorId id, Node* src) { + auto other_id = state_map_.LookupAncestorId(src); + if (other_id != id && other_id != nullptr) { + state.insert(other_id->begin(), other_id->end()); + } + if (IsSwitch(src) || IsMerge(src)) { + state.insert(src); + } + return state_map_.GetAncestorId(state); + }; + + // Compute the union of all the switch/merge nodes that affects the input of + // dst. + for (auto e : dst->in_edges()) { + Node* src = e->src(); + id = insert(id, src); + } + state_map_.ResetAncestorId(dst, id); + return Status::OK(); +} + +void FunctionalizeCond::DeleteReachableAndDeadNodes( + const std::vector& switch_ids, const std::vector& merge_order) { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. The input and outgoing edges should have already been // removed. + std::deque delete_nodes; std::vector deleted(graph_->num_node_ids(), false); // Don't try to delete source or sink nodes. 
deleted[graph_->kSourceId] = true; deleted[graph_->kSinkId] = true; - while (!delete_nodes_.empty()) { - int d_id = delete_nodes_.front(); - delete_nodes_.pop_front(); + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. + for (int s_id : switch_ids) { + Node* s = graph_->FindNodeId(s_id); + if (s == nullptr) continue; + for (const Edge* e : s->out_edges()) { + // Control outputs of switch nodes (which are unconditionally executed if + // the switch is) are not removed as they need not be part of a + // conditional. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[s_id] = true; + graph_->RemoveNode(s); + } + + // All merge nodes should have been transformed at this point and we remove + // them from the graph here. + for (Node* m : merge_order) { + for (const Edge* e : m->out_edges()) { + // Similar to control outputs of switch nodes don't remove control + // outputs of merge nodes. + // TODO(jpienaar): Check cases where output edges still exist here vs + // being removed in AddOutputEdges. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[m->id()] = true; + graph_->RemoveNode(m); + } + + // Enqueue all the dead nodes. + for (Node* n : graph_->nodes()) { + if (state_map_.IsDead(state_map_.LookupCondId(n))) { + delete_nodes.push_back(n->id()); + } + } + + while (!delete_nodes.empty()) { + int d_id = delete_nodes.front(); + delete_nodes.pop_front(); if (deleted[d_id]) continue; Node* d = graph_->FindNodeId(d_id); // Switch and Merge nodes could have been deleted already. 
if (d == nullptr) continue; for (const Edge* e : d->out_edges()) { - delete_nodes_.push_back(e->dst()->id()); + delete_nodes.push_back(e->dst()->id()); } deleted[d_id] = true; graph_->RemoveNode(d); @@ -1239,16 +1253,8 @@ void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { inner_to_outer_merge_order.reserve(merge_order->size()); for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { Node* merge = *it; - CondStateMap::CondId id = cond_state_map_.LookupId(merge); - int depth = 0; - for (auto cond_node_it = id->begin(); cond_node_it != id->end(); - ++cond_node_it) { - if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch && - (cond_node_it->branch == BranchType::kThenBranch || - cond_node_it->branch == BranchType::kElseBranch)) { - ++depth; - } - } + StateMap::CondId id = state_map_.LookupCondId(merge); + int depth = id != nullptr ? id->size() : 0; inner_to_outer_merge_order.emplace_back(depth, merge); } std::stable_sort( @@ -1271,10 +1277,10 @@ Status FunctionalizeCond::FunctionalizeInternal() { // determine deeper equivalence). We shall refer to this structure as the // CondState; // 3. Sort the merge nodes by nesting depth; - // 4. Extract merge nodes together that have the same CondState and whose - // input nodes have the same state from the innermost to the outermost into - // IfOps; Note: In the above only nodes paths that converge to a merge node - // will be considered for removal. + // 4. Extract merge nodes together that have the same CondState and + // AncestorState from the innermost to the outermost into IfOps; + // Note: In the above only nodes that feed into a merge node will be + // considered for functionalization. 
// Perform a DFS over the graph and // * Determine the reverse topological order of the nodes (there should be no @@ -1306,50 +1312,46 @@ Status FunctionalizeCond::FunctionalizeInternal() { return Status::OK(); } - TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order))); - - if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); + TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); + if (VLOG_IS_ON(4)) DumpGraphWithCondState("id"); // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Extract from innermost out. - for (auto it = merge_order.begin(); it != merge_order.end(); ++it) { - Node* merge = *it; - auto id = cond_state_map_.LookupId(merge); - if (cond_state_map_.IsDead(id)) continue; - - // Construct a Conditional with the predicate of the merge (which is the - // last entry of the CondState for the merge) and this as parent. - DCHECK(id->back().predicate.node != nullptr); - Conditional cond(id->back().predicate, this, &cond_state_map_); - TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - - // Find all merge nodes with the same CondId. This is done repeatedly as - // the CondId can change due replaced conditionals. E.g., the one branch - // could previously have had a conditional nested in it, and so would have - // had CondState with sub-state [switch(p,b),m] (where p is some predicate), - // post removing the nested conditional that sub-state would no longer be - // path of the propagated state along that path. - auto end = merge_order.end(); - for (auto merge_candidate_it = std::next(it); merge_candidate_it != end; - ++merge_candidate_it) { - auto merge_candidate_it_id = - cond_state_map_.LookupId(*merge_candidate_it); - if (merge_candidate_it_id != id) continue; - TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it)); + // Cluster merge nodes by CondId and AncestorId in order of nesting. 
+ using ClusterPair = std::pair; + std::deque> merge_clusters; + std::map merge_cluster_index; + for (Node* merge : merge_order) { + auto cond_id = state_map_.LookupCondId(merge); + if (state_map_.IsDead(cond_id)) continue; + + ClusterPair key = + std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto idx = merge_cluster_index.find(key); + if (idx == merge_cluster_index.end()) { + merge_cluster_index[key] = merge_clusters.size(); + merge_clusters.push_back({merge}); + } else { + merge_clusters[idx->second].emplace_back(merge); } + } + // Extract the conditionals from inner most to outer most. Extracting from + // innermost to outermost enables the extraction pass to stop once it + // encounters a Switch node instead of having to keep track of Switch/Merge + // nodes seen. + for (const auto& cluster : merge_clusters) { + // Construct a Conditional with the predicate of the merge. + Conditional cond(merge_to_predicate_.at(cluster.front()), this, + &state_map_); + for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } - // All remaining Switch nodes are not reachable from a Merge node and - // removed. This is to account for dead Switch nodes. 
- for (int s_id : switch_ids) delete_nodes_.push_back(s_id); - for (Node* m : merge_order) delete_nodes_.push_back(m->id()); - DeleteReachableNodes(); + DeleteReachableAndDeadNodes(switch_ids, merge_order); return Status::OK(); } @@ -1359,11 +1361,14 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { for (Node* n : graph_->nodes()) { n->ClearAttr(kCondGroupDebugAttr); - n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n)); + n->AddAttr(kCondGroupDebugAttr, + absl::StrCat(state_map_.CondStateToString(n), "_", + state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " << dump_graph::DumpGraphToFile( - strings::StrCat("functionalize_", name), *graph_, library_); + absl::StrCat("functionalize_cond_", name), *graph_, + library_); } Status FunctionalizeCond::Functionalize(Graph* graph, diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 86436011c6ebdc608a5811a1b0d6a10015d405bd..189980894073b1da1a12d1c284536336eb920900 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,59 +43,53 @@ enum class BranchType { kNeither = 3, }; -// CondStateMap is responsible for mapping from each graph Node to a CondState, -// where each CondState is the array of CondNodes (corresponding to switch, -// merge or dead states) as described below. For efficiency, this class interns -// the CondState, so that CondState equality comparisons are simply pointer +// StateMap is responsible for mapping from each graph Node to +// * a CondState, where each CondState is a map from predicate to branch (i,e., +// what predicates have to hold or not hold). 
+// * a AncestorState, where each AncestorState is a set of switch/merge nodes +// that are an ancestor of the node in the graph; +// For efficiency, this class interns the CondState (AncestorState), so that +// CondState (AncestorState) equality comparisons are simply pointer // comparisons. -class CondStateMap { +class StateMap { public: - explicit CondStateMap(Graph* graph); - - // Represents an entry in the CondState. An entry can either be the - // switch (along with predicate), merge, or dead: - // * switch node indicates a node that is executed along a branch with the - // given predicate - a branch can be then, else or both; - // * merge node indicates that the node is executed as output of a merge; - // * dead indicates that this node can never be executed; - struct CondNode { - enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 }; - - CondNode(Type type, Node* switch_node = nullptr, - BranchType branch = BranchType::kNeither); - - string ToString() const; - bool operator==(const CondNode& other) const; - bool operator!=(const CondNode& other) const; - - // Type of node. - Type type; - - // Predicate and branch, only used when type is kSwitch. - OutputTensor predicate; - BranchType branch; + explicit StateMap(Graph* graph); + + // Compare two OutputTensors by (node id, index). + struct OutputTensorLess { + bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; }; - // A node in the graph is executed when multiple conditions hold. The order - // represents the nesting of the predicates that hold and is used when - // extracting the nested conditionals. - using CondState = std::vector; + // A node in the graph is executed when multiple conditions hold. Keep track + // of the predicates that must hold for a node to execute. + using CondState = std::map; // Every unique ID is mapped to a CondState. using CondId = const CondState*; + // Keep track of which switch/merge node's feed into a node's values. 
+ using AncestorState = std::set; + + // Every unique ID is mapped to a AncestorState. + using AncestorId = const AncestorState*; + // Returns the CondId for a given node. - CondId LookupId(const Node* node) const; + CondId LookupCondId(const Node* node) const; // Returns the unique CondId for CondState. - CondId GetUniqueId(const CondState& state); - - // Returns the CondState for a Node. - // REQUIRES: node has a non-empty CondState. - const CondState& LookupState(const Node* node) const; + CondId GetCondId(const CondState& state); // Resets the CondId for a given node. - void ResetId(const Node* node, CondId id); + void ResetCondId(const Node* node, CondId id); + + // Returns the AncestorId for a given node. + AncestorId LookupAncestorId(const Node* node) const; + + // Returns the unique AncestorId for CondState. + AncestorId GetAncestorId(const AncestorState& state); + + // Resets the AncestorId for a given node. + void ResetAncestorId(const Node* node, AncestorId id); // Marks `node` as dead. void MarkDead(const Node* node); @@ -103,45 +97,30 @@ class CondStateMap { // Determine branch execution of CondState. BranchType FindBranchOf(CondId id, OutputTensor predicate) const; - // Enum to represent whether one cond flow state contains another. - enum ContainsResult { - kIncomparable, - kEqual, - kLhsContainsRhs, - kRhsContainsLhs - }; - - // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e., - // [(p,t)] contains [(p,t), (r,t)]. - ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs); - // Returns textual representation of node's CondState. string CondStateToString(const Node* node) const; string CondStateToString(CondId id) const; + // Returns textual representation of node's AncestorState. + string AncestorStateToString(const Node* node) const; + // Returns whether the cond state is the dead state. bool IsDead(CondId id) const; // Returns whether the cond state is the empty state. 
bool IsEmpty(CondId id) const; - // Computes the predicates that have to hold for a node to execute and returns - // whether it was possible to determine the predicates that must hold. `scope` - // is populated with these predicates. Scope differs from state in that it - // does not include merge and both nodes. - bool ScopeIn(CondId id, CondId* scope); - private: - // Hash for CondNode and CondState. - struct CondHash { - size_t operator()(const CondNode& item) const; - size_t operator()(const CondState& vec) const; + // Hash for CondState and AncestorState. + struct Hash { + size_t operator()(const CondState& map) const; + size_t operator()(const AncestorState& map) const; }; // Set to keep track of unique CondStates. // Pointers to the entries in the unordered set are used as identifiers: // unordered_set guarantees that the pointers remain the same. - std::unordered_set condstate_set_; + std::unordered_set condstate_set_; // Mapping from Node id to CondId. std::vector node_to_condid_map_; @@ -150,7 +129,12 @@ class CondStateMap { // from Node id in the original graph to the CondId, but there will be nodes // added to the original graph (such as If nodes) whose CondState needs to be // tracked too. - std::unordered_map added_node_mapping_; + std::unordered_map added_node_condid_mapping_; + + // AncestorId variants of the CondId members. + std::unordered_set ancestorstate_set_; + std::vector node_to_ancestorid_map_; + std::unordered_map added_node_ancestorid_mapping_; // Identifier of the dead flow state. The empty flow state is represented with // a nullptr. @@ -173,7 +157,8 @@ class FunctionalizeCond { // Add a If node to the graph defined by def that will, amongst other, replace // replacee in the graph. - xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee); + xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); // Propagates the state of a newly inserted node. 
Status PropagateUpdatedState(const Node* replacee); @@ -185,35 +170,42 @@ class FunctionalizeCond { FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); // Performs the actual cond functionalization. Iterate over groups of merge - // nodes (linked by common predicate & CondIds of the incomming edges), - // from innermost to outermost, and extract into If nodes. + // nodes (linked by common predicates & ancestor IDs), from innermost to + // outermost, and extract into If nodes. Status FunctionalizeInternal(); // Returns the forward flow state propagated along edge `e`. - // This may modify cond_state_map_. - CondStateMap::CondId StateAlongEdge(const Edge* e); + // This may modify state_map_. + StateMap::CondId StateAlongEdge(const Edge* e); - // Determines the CondState of all the nodes in the given vector where - // the input is expected in reverse topological order. - // This populates the cond_state_map_. - Status DetermineCondStates(std::vector rev_topo_order); + // Determines the CondState and AncestorState of all the nodes in the given + // vector where the input is expected in reverse topological order. + // This populates the state_map_. + Status DetermineStates(std::vector rev_topo_order); // Determine the CondState for a given node using the incomming edges // to the node. Note: it is expected that this node's CondState is only // determined once its input's CondState is. - Status DetermineCondState(Node* dst); + Status DetermineCondState(Node* dst) { + if (IsMerge(dst)) return DetermineCondStateMerge(dst); + return DetermineCondStateNonMerge(dst); + } // Helper functions for DetermineCondState. + Status DetermineCondStateNonMerge(Node* dst); Status DetermineCondStateMerge(Node* dst); - // Helper functions for DetermineCondStates. Determines the dst node's - // CondState by joining the src and dst's CondState where either - // the dst node is a merge or not. - // These may modify cond_state_map_. 
- xla::StatusOr JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); - xla::StatusOr JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); + // Determines the dst node's CondState by joining the src and dst's CondState + // where either the dst node is a merge or not. + // These may modify state_map_. + xla::StatusOr JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); + + // Determines which switch/merge nodes are ancestors of this node. + Status DetermineAncestorState(Node* dst); // Checks if a merge node is redundant and if so removes it from the graph. Status RemoveRedundantMerge(Node* node); @@ -225,15 +217,18 @@ class FunctionalizeCond { // nesting depth. void SortMergeNodes(std::vector* merge_order); - // Deletes all nodes in/consumers of `delete_nodes_`. - void DeleteReachableNodes(); + // Deletes all nodes in/consumers reachable from switch/merge nodes that were + // extracted. + void DeleteReachableAndDeadNodes(const std::vector& switch_ids, + const std::vector& merge_order); - // Member used to unique the CondState to a unique CondId and keep track of - // CondState/CondId per Node. - CondStateMap cond_state_map_; + // Member used to unique the CondState to a unique CondId (AncestorState to a + // unique AncestorId) and keep track of CondState/CondId + // (AncestorState/AncestorId) per Node. + StateMap state_map_; - // Nodes to be deleted. - std::deque delete_nodes_; + // Mapping from merge nodes to predicate. 
+ std::unordered_map merge_to_predicate_; FunctionLibraryDefinition* library_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index a27f8893925855f536801a8a68855b82ac07462d..b0aabd63bbda784b3b7103a438ce025eea0cd93b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -37,28 +37,23 @@ class FunctionalizeCondTest : public ::testing::Test { flib_def_.get())); } - CondStateMap::CondId GetUniqueId( - const CondStateMap::CondStateMap::CondState& state) { - return fc_->cond_state_map_.GetUniqueId(state); + StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) { + return fc_->state_map_.GetCondId(state); } - xla::StatusOr JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesNonMerge(src, dst); - } - - xla::StatusOr JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesMerge(src, dst); + string GetString(const StateMap::StateMap::CondId id) { + return fc_->state_map_.CondStateToString(id); } - bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) { - return fc_->cond_state_map_.ScopeIn(ff, scope); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); } - CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs); + xla::StatusOr JoinCondStatesMerge(Node* n, + StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesMerge(n, src, dst); } FunctionDefLibrary fdef_lib_; @@ -69,50 +64,6 @@ class FunctionalizeCondTest : public ::testing::Test { namespace { -TEST_F(FunctionalizeCondTest, ScopeIn) { - Tensor pred_tensor(DT_BOOL, TensorShape()); - 
pred_tensor.flat().setZero(); - Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); - Tensor val_tensor(DT_INT32, TensorShape()); - val_tensor.flat().setZero(); - Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); - - { - CondStateMap::CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope; - ASSERT_TRUE(ScopeIn(id, &scope)); - ASSERT_TRUE(id == scope); - } - - CondStateMap::CondState empty; - { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope_1; - ASSERT_TRUE(ScopeIn(id, &scope_1)); - ASSERT_TRUE(scope_1 == GetUniqueId(empty)); - ASSERT_TRUE(id != scope_1); - - ss.clear(); - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - id = GetUniqueId(ss); - CondStateMap::CondId scope_2; - ASSERT_TRUE(ScopeIn(id, &scope_2)); - - ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) == - CondStateMap::ContainsResult::kLhsContainsRhs); - } -} - TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor pred_tensor(DT_BOOL, TensorShape()); pred_tensor.flat().setZero(); @@ -120,22 +71,18 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor val_tensor(DT_INT32, TensorShape()); val_tensor.flat().setZero(); Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); + Node* m = test::graph::Merge(graph_.get(), val, val); - CondStateMap::CondId empty = GetUniqueId({}); - - CondStateMap::CondId then_branch; + StateMap::CondId then_branch; { - CondStateMap::CondState ss; 
- ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch)); then_branch = GetUniqueId(ss); } - CondStateMap::CondId else_branch; + StateMap::CondId else_branch; { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch)); else_branch = GetUniqueId(ss); } @@ -144,39 +91,14 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { EXPECT_TRUE(errors::IsInvalidArgument(status)); // Merge between then and else branch. - auto joined_or = JoinCondStatesMerge(then_branch, else_branch); + auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); TF_EXPECT_OK(joined_or.status()); - CondStateMap::CondId joined = joined_or.ValueOrDie(); + StateMap::CondId joined = joined_or.ValueOrDie(); // Merge between then branch and both branch. auto t = JoinCondStatesNonMerge(then_branch, joined); // Note: this is OK in terms of constraint predication, but TF_EXPECT_OK(t.status()); - - // Post merge the propagated forward flow state has an additional merge. - CondStateMap::CondId post_merge; - { - CondStateMap::CondState ss; - ss = *joined; - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - post_merge = GetUniqueId(ss); - } - - t = JoinCondStatesNonMerge(post_merge, joined); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(joined == t.ValueOrDie()); - - // No predicate that results in two paths predicated on different conditions - // merge. - t = JoinCondStatesMerge(post_merge, joined); - EXPECT_FALSE(t.ok()); - - // Post the merge we are effectively in the root scope and merging should - // result in the more restrictive post merge state. 
- t = JoinCondStatesNonMerge(post_merge, empty); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(post_merge == t.ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5932be4e525dec11a8f3c59bb85e0449e76e79c0..f818d80022da0bad851c896f2714c15b20b22195 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,18 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -68,4 +75,198 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + std::map>* canonicalized_name_to_new_name, + bool* modified) { + *modified = false; + + // Convert the function to Graph. 
+ FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; + + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // We cannot return here directly if the graph has no Switch/Merge. + // It might contain function call nodes, or If/While nodes with Switch/Merge + // in function body. We still need to rewrite those functions and modify + // corresponding nodes. + + // If any node has associated functions, functionalize them first. + // Gather nodes with associated functions first, because rewriting those nodes + // might involve node deletion/addition. Avoid modifying nodes while iterating + // it. + std::vector>> + nodes_to_associated_functions; + for (auto* n : g->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, fld); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (auto iter : nodes_to_associated_functions) { + Node* n = iter.first; + auto associated_functions = iter.second; + for (auto& associated_function : associated_functions) { + string name = associated_function.func_name(); + string canonicalized_name = + Canonicalize(name, AttrSlice(&associated_function.attrs())); + auto iter = canonicalized_name_to_new_name->find(canonicalized_name); + string new_name; + bool function_modified; + if (iter != canonicalized_name_to_new_name->end()) { + // If we already processed this function, check if it was rewritten. 
If + // the function was rewritten, the entry will be non-empty. Otherwise + // the entry will be empty. + function_modified = iter->second.has_value(); + if (function_modified) { + new_name = iter->second.value(); + } + } else { + if (associated_function.type() == + AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { + // For SymbolicGradient, `name` is always "SymbolicGradient", + // which is not very informative. Use node name instead. + new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + name, new_name, associated_function.attrs(), fld, flr, + canonicalized_name_to_new_name, &function_modified)); + if (function_modified) { + // If the function was rewritten, add an non-empty entry. So later we + // know we have processed this function, and it was rewritten into + // another function. + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } else { + // If the function was not rewritten, add an empty entry. So later + // we know we have processed this function, and it does not need to be + // rewritten. + (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; + } + } + if (function_modified) { + *modified = true; + + // Notice that if "n" is a function call, RewriteAssociatedFunction() + // will delete it and create a new node instead, making "n" an invalid + // pointer. That's fine because in that case, associated_functions will + // only have one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + g, n, fld, associated_function, new_name)); + } + } + } + + if (has_switch_or_merge) { + *modified = true; + + // Functionalize the function body. 
+ if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } + } + + if (*modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + } + + return ret_status; +} + +Status FunctionalizeControlFlowPass::Run( + const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); + } + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, options.session_options->env, + TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // Find XLA compile ops and its corresponding FunctionDef. + // TPUCompile op is not in the map because graph rewriting might happen + // multiple times, and we want to avoid functionalize it again. + static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ + // TPUReplicate ops are generated by EncapsulateTPUComputationsPass. + {"TPUReplicate", "computation"}, + // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. 
+ {"XlaLaunch", "function"}, + }; + std::map> canonicalized_name_to_new_name; + for (Node* n : graph->nodes()) { + auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); + if (it == kNodeTypeToFunctionAttrMapping->end()) { + continue; + } + const string func_attr = it->second; + if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != + kNodeTypeToFunctionAttrMapping->end()) { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + bool modified; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } + } + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 55600f2a8b5302cef26b9be4ccd0f8804476a17a..ba99205640ccdc83a3a4d50e3ec474907894a835 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -32,6 +33,14 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (If/While). +class FunctionalizeControlFlowPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc new file mode 100644 index 0000000000000000000000000000000000000000..a10a9d0499457bbc0383ea3a8c678f153e21894b --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +namespace tensorflow { + +// This pass is required for some AOT backends and all JIT backends, so this +// file exists as a separate lib and will be linked to both AOT and JIT. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, + FunctionalizeControlFlowPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index cc52057f214a45a861660c3d34cbbffd9c45a640..c3841f996f801e855da75b23f01d41674ec51c4d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" @@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, then_fn, - else_fn, {DT_INT32}); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); - // TODO(jpienaar): Create wrapper for IfOp. 
- for (NodeDef& n : *expected.mutable_node()) { - if (n.op() == "XlaIf") n.set_op("If"); - } TF_EXPECT_GRAPH_EQ(expected, graph_def); } @@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) { Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, NameAttrList* body) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaWhile") { + if (node.op() == "While") { const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); *cond = *result; @@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, return Status::OK(); } } - return errors::NotFound("No XlaWhile node found in graph"); + return errors::NotFound("No While node found in graph"); } // Graph: @@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_ASSERT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); 
auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); GraphDef expected; @@ -805,11 +802,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); @@ -823,7 +820,7 @@ TEST(FunctionalizeControlFlow, Complex) { scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ exit_j.output.op(), exit_k.output.op()}), identity_i, one_outer); auto next_iteration_i = @@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), - 
std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -921,15 +918,15 @@ TEST(FunctionalizeControlFlow, Complex) { auto one_j = ops::Const( scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto while_op = - ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); auto one_outer = ops::Const( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ while_op[0].op(), while_op[1].op()}), identity_i, one_outer); @@ -991,11 +988,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index 924fcdd9cd72a6472e0b2748680f2552fa65ec79..54cebc61778ba051b9c903f8e2c3696cec69843a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -42,7 +42,7 @@ xla::StatusOr 
BuildRetvalNode(Graph* graph, DataType type, int index) { const char* const kRetValOp = "_Retval"; NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(kRetValOp, index)); + ret_def.set_name(absl::StrCat(kRetValOp, index)); AddNodeAttr("T", type, &ret_def); AddNodeAttr("index", index, &ret_def); return AddNodeDefToGraph(ret_def, graph); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 61940e3586c59ffc660eaac8f8d035fbbbdfeffd..582b49d5116acc651fb6242b5c2b9aeeac269532 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -43,13 +43,12 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); // Returns a textual representation of the names of the nodes in the input. template string NodesToString(const T& nodes) { - return strings::StrCat("{", - absl::StrJoin(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + absl::StrAppend(output, node->name()); + }), + "}"); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 6e3c4b0e0f695f0073f2c8aa1a4b342e39ea4be5..d87436a7b4ac37c74d0f0df921779c8716290013 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -132,7 +134,7 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, StatusOr BuildArgNode(Graph* graph, DataType type, int index) { const char* const kArgOp = "_Arg"; NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); + NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); builder.Attr("index", index); TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); @@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. + // Builds the condition and body functions. Notice that we call + // FunctionalizeCond() on cond_graph and body_graph because we might have + // unfunctionalized "if" in cond_graph and body_graph. Functionalize them + // before they are encapsulated in FunctionDef. 
std::unique_ptr cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + FixupSourceAndSinkEdges(cond_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + FixupSourceAndSinkEdges(body_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -487,9 +496,9 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, static std::atomic sequence_num(0LL); int64 id = ++sequence_num; NameAttrList cond_name; - cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + cond_name.set_name(absl::StrCat("_functionalize_cond_", id)); NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_body_", id)); + body_name.set_name(absl::StrCat("_functionalize_body_", id)); FunctionDef cond_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); @@ -510,10 +519,16 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. 
NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + NodeDefBuilder builder(frame->loop_cond->name(), "While", library); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); + string outside_compilation; + if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } std::vector inputs; for (int i = 0; i < frame->args.size(); ++i) { const Arg& arg = frame->args[i]; @@ -653,9 +668,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, // There should be no cycle at this point, since while loops have been removed // from graph. - // Check that the newly added XlaWhile nodes don't feed into themselves. + // Check that the newly added While nodes don't feed into themselves. for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { + if (node->def().op() == "While") { TF_RETURN_WITH_CONTEXT_IF_ERROR( CheckNodeNotInCycle(node, graph->num_node_ids()), "Functionalizing loop failed."); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index ba37ed33370f04ff51ff4c448673be61905faccf..c019a28e892ff89f559ddbec2360d6caa9c1808f 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -81,7 +80,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(auto literal, client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + LiteralToHostTensor(literal, arg.type, &arg.constant_value)); } else { arg.kind = XlaCompiler::Argument::kParameter; } @@ -127,7 +126,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) << "Not supported node: " << n->DebugString(); params.op_kernel = op_kernel.get(); - gtl::InlinedVector output_attr(n->num_outputs()); + absl::InlinedVector output_attr(n->num_outputs()); params.output_attr_array = output_attr.data(); // tensor_inputs_ is a buffer reused across graph traversal. We clean up and @@ -146,6 +145,7 @@ Status GraphCompiler::Compile() { } OpKernelContext op_context(¶ms, n->num_outputs()); + VLOG(3) << "Translating " << params.op_kernel->name(); if (IsFunctional(n)) { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index 127562eb23d775f17179cc9ee968ec2255cf3a14..e9f02201cf6bed5495dff7dff76c5bafe7771516 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -55,17 +55,17 @@ namespace tensorflow { // op registration infrastructure instead of FunctionLibraryRuntime. 
class GraphCompiler { public: - GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, - Graph* graph, FunctionLibraryRuntime* flib, + GraphCompiler(XlaCompilationDevice* device, Graph* graph, + FunctionLibraryRuntime* flib, ScopedStepContainer* step_container) - : xla_context_(xla_context), - device_(device), + : device_(device), graph_(graph), flib_(flib), step_container_(step_container) {} - // Compiles the graph. The results are written in `xla_context` that is passed - // into the compiler. + // Compiles the graph. The results are written in xla_context stored in the + // resource_manager of the 'XlaCompilationDevice' that's passed into the + // constructor. Status Compile(); private: @@ -82,14 +82,13 @@ class GraphCompiler { // using `compiler_`. Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); - XlaContext* xla_context_; XlaCompilationDevice* device_; Graph* graph_; FunctionLibraryRuntime* flib_; ScopedStepContainer* step_container_; // A buffer to hold tensor inputs to a node, this is reused across the graph // traversal. 
- gtl::InlinedVector tensor_inputs_; + absl::InlinedVector tensor_inputs_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index c1438f893f6d3c46dd7f6c39b6aa3367a79789f0..224e5ea123b4905bcfe0947722dbaf4a703f9893 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -62,6 +62,7 @@ tf_kernel_library( "one_hot_op.cc", "pack_op.cc", "pad_op.cc", + "permute_op.cc", "pooling_ops.cc", "qr_op.cc", "quantize_and_dequantize_op.cc", @@ -94,6 +95,7 @@ tf_kernel_library( "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", + "tensor_list_ops.cc", "tile_ops.cc", "topk_op.cc", "training_ops.cc", @@ -113,13 +115,13 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":conv_op_helpers", ":if_op", ":while_op", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", @@ -158,6 +160,7 @@ tf_kernel_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:pooling_ops", @@ -167,14 +170,32 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - ] + if_mkl( - [ - "//tensorflow/core/kernels:mkl_transpose_op", - ], - [ - "//tensorflow/core/kernels:transpose_op", - ], - ), + "//tensorflow/core/kernels:transpose_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "conv_op_helpers", + srcs = ["conv_op_helpers.cc"], + hdrs = ["conv_op_helpers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/types:span", + ], ) tf_kernel_library( @@ -183,6 +204,7 @@ tf_kernel_library( hdrs = ["while_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", @@ -200,6 +222,7 @@ tf_kernel_library( hdrs = ["if_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index b3ad0aea84eef601de08909f760699b8700d28f4..a267c0c72fce67d7c22c55a57f8d5ac4ffd2b7e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == 
FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 48f2a005ab16651fe29d0f6f9d881f95693da461..a18e04995b5e1e0b0374f7b0edd6f5e114cf994a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -23,10 +23,10 @@ namespace { void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); @@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc 
b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 2e383b1473590403823863f89264e5381d8e8806..182f7c99344845964f7010127718f876ab6e8a44 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -39,7 +39,7 @@ class BCastArgsOp : public XlaOpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), @@ -88,7 +88,7 @@ class BCastGradArgsOp : public XlaOpKernel { ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2c328102e0bd84709707f102272691b6aec9a577..47e517a6576d3a848bc41ceb703df2bd778c4a35 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -30,21 +31,21 @@ namespace { // A subclass of a XlaBinaryOp must build the computation that // describes the (tensor,tensor)->tensor function to apply to each element of // the input. 
-#define XLA_MAKE_BINARY(NAME, HLO) \ - class NAME##Op : public XlaBinaryOp { \ - public: \ - explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::XlaOp Computation( \ - XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ - const gtl::ArraySlice& rhs_shape, \ - const BCast& broadcast_helper, \ - const std::vector& extend_dimensions) override { \ - xla::XlaBuilder* b = ctx->builder(); \ - (void)b; \ - return HLO; \ - } \ - }; \ +#define XLA_MAKE_BINARY(NAME, HLO) \ + class NAME##Op : public XlaBinaryOp { \ + public: \ + explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const absl::Span& lhs_shape, const xla::XlaOp& rhs, \ + const absl::Span& rhs_shape, \ + const BCast& broadcast_helper, \ + const std::vector& extend_dimensions) override { \ + xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ + return HLO; \ + } \ + }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); @@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +// Implementation of DivNoNan. Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x / y; +// } +static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y)); + return result; +} +XLA_MAKE_BINARY(DivNoNan, + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. 
Pseudo-code:
 //   if ((x < 0) != (y < 0)) {
 //     T abs_x = std::abs(x);
@@ -65,7 +84,10 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
 //   }
 static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
+  if (DataTypeIsUnsigned(dtype)) {
+    return xla::Div(x, y);
+  }
   auto zero = XlaHelpers::Zero(b, dtype);
   auto one = XlaHelpers::One(b, dtype);
   auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
@@ -81,12 +103,48 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
 XLA_MAKE_BINARY(FloorDiv,
                 FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
 
+// Implementation of Xlogy. Pseudo-code:
+//   if (x == 0) {
+//     return 0;
+//   } else {
+//     return x * log(y);
+//   }
+static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+                            xla::XlaOp y, const BCast& broadcast_helper) {
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
+  auto zero = XlaHelpers::Zero(b, dtype);
+  auto is_zero = xla::Eq(x, zero);
+  // Both Select branches must have the same shape, so the zero branch is built
+  // with the shape of the broadcast operand, exactly as DivNoNanImpl does.
+  auto zeros = xla::ZerosLike(x);
+  return xla::Select(is_zero, zeros, xla::Mul(x, xla::Log(y)));
+}
+XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+// Implementation of Xdivy. Pseudo-code:
+//   if (x == 0) {
+//     return 0;
+//   } else {
+//     return x / y;
+//   }
+static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+                            xla::XlaOp y, const BCast& broadcast_helper) {
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
+  auto zero = XlaHelpers::Zero(b, dtype);
+  auto is_zero = xla::Eq(x, zero);
+  // Both Select branches must have the same shape, so the zero branch is built
+  // with the shape of the broadcast operand, exactly as DivNoNanImpl does.
+  auto zeros = xla::ZerosLike(x);
+  return xla::Select(is_zero, zeros, xla::Div(x, y));
+}
+XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
 // Implementation of FloorMod. Pseudo-code:
 //   T trunc_mod = std::fmod(x, y);
 //   return (x < T(0)) == (y < T(0)) ?
trunc_mod : std::fmod(trunc_mod + y, y); static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); auto trunc_mod = xla::Rem(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 4bd7c74dca2a7cbb51f2a329ac575d635f314516..9bb11fb67e3e4ddc48d68631c60f96c60b921094 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -37,60 +32,9 @@ class BroadcastToOp : public XlaOpKernel { TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); - OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), - errors::InvalidArgument( - "Input rank (", input_shape.dims(), - ") must be less than or equal to the output rank (", - output_shape.dims(), ")")); - - auto input_dims = input_shape.dim_sizes(); - auto 
output_dims = output_shape.dim_sizes(); - - // Broadcasting is done right-to-left on right-aligned dimensions; reverse - // the two vectors so elements to be broadcast are aligned. - absl::c_reverse(input_dims); - absl::c_reverse(output_dims); - - std::vector broadcast_dims; - std::vector broadcast_shape; - for (int i = 0; i < output_shape.dims(); ++i) { - if (i < input_shape.dims()) { - OP_REQUIRES( - context, - (output_dims[i] == 0 && input_dims[i] == 0) || - (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), - errors::InvalidArgument("invalid shape to broadcast from ", - input_shape.DebugString(), " to ", - output_shape.DebugString())); - - broadcast_dims.push_back(broadcast_shape.size()); - if (output_dims[i] == input_dims[i] || input_dims[i] == 1) { - broadcast_shape.push_back(output_dims[i]); - } - if (output_dims[i] != input_dims[i]) { - // Add dimensions [I, O/I], which we will later flatten to just - // [O]. We must do this in two phases since XLA broadcasting does not - // support tiling. 
- broadcast_shape.push_back(input_dims[i]); - broadcast_shape.push_back(output_dims[i] / input_dims[i]); - } - } else { - broadcast_shape.push_back(output_dims[i]); - } - } - absl::c_reverse(broadcast_dims); - int broadcast_shape_size = broadcast_shape.size(); - for (int64& broadcast_dim : broadcast_dims) { - broadcast_dim = broadcast_shape_size - broadcast_dim - 1; - } - absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::Reshape( - xla::BroadcastInDim(context->Input(0), - xla::ShapeUtil::MakeShape( - context->input_xla_type(0), broadcast_shape), - broadcast_dims), - output_shape.dim_sizes()); - context->SetOutput(0, output); + auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output.status()); + context->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f4106051043859a6786705009d76b02a64cd3ff1..0ae23aa6dfe49048ac5cb8ae00c12432b2e2a2fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. 
std::vector input_data; + std::vector partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. + partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readibility of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index da8cf3fc6fa694f592280f8c249d317827d9cd09..2628ef8e2454976aeff3859fa5dc1d8e106f32e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX64: + if (proto_.scomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0( + b, xla::complex64(proto_.scomplex_val(0), + proto_.scomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9a1be494066e4f935a1d818bc86c86333e34fae --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -0,0 +1,509 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D convolution. 
+ +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +// Returns the expanded size of a filter used for depthwise convolution. +// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. +xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { + int num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); // Crash OK + xla::Shape expanded_shape = shape; + expanded_shape.set_dimensions( + num_dims - 1, + shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1)); + return expanded_shape; +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. 
For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tensor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, + xla::XlaBuilder* builder) { + xla::Shape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = + filter_shape.dimensions(filter_shape.dimensions_size() - 1); + int64 input_feature = + filter_shape.dimensions(filter_shape.dimensions_size() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. 
+ std::vector expanded_feature_broadcast_dims( + expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dimensions_size() - 2}); +} + +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dimensions_size() - 2; + int64 output_feature_dim = filter_shape.dimensions_size() - 1; + int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim); + int64 input_feature = filter_shape.dimensions(input_feature_dim); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + xla::Shape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dimensions( + output_feature_dim, depthwise_multiplier * input_feature); + return xla::Reshape( + filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions())); +} + +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. 
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { + auto masked_expanded_filter = + xla::Select(CreateExpandedFilterMask(filter_shape, builder), + filter_backprop, xla::ZerosLike(filter_backprop)); + + auto elem_type = filter_shape.element_type(); + return xla::Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the select above guarantees + // that only one element is non zero, so there cannot be accumulated + // precision error. + xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type), + CreateScalarAddComputation(elem_type, builder), + {filter_shape.dimensions_size() - 2}), + xla::AsInt64Slice(filter_shape.dimensions())); +} + +// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA +// convolutions (as currently implemented). +Status CheckConvAttrs(const ConvOpAttrs& attrs) { + const int num_dims = attrs.num_spatial_dims + 2; + if (attrs.strides.size() != num_dims) { + return errors::InvalidArgument("Sliding window strides field must specify ", + num_dims, " dimensions"); + } + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not yet support strides in the batch and " + "depth dimensions."); + } + if (attrs.dilations.size() != num_dims) { + return errors::InvalidArgument("Dilations field must specify ", num_dims, + " dimensions"); + } + if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not support dilations in the batch and " + "depth dimensions."); + } + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, 
attrs.data_format, i); + if (attrs.dilations[input_dim] < 1) { + return errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + attrs.dilations[input_dim]); + } + } + return Status::OK(); +} + +// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes +// to TensorShapes. +Status ConvBackpropComputeDimensionsV2XlaShapes( + StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, + const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, + absl::Span dilations, const std::vector& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + TensorShape input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); + return ConvBackpropComputeDimensionsV2( + label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape, dilations, strides, padding, data_format, + dims); +} + +} // anonymous namespace + +xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, + bool depthwise, + OpKernelConstruction* ctx) { + ConvOpAttrs attrs; + attrs.num_spatial_dims = num_spatial_dims; + attrs.depthwise = depthwise; + TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); + TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); + TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + + string data_format; + TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); + if (!FormatFromString(data_format, &attrs.data_format)) { + return errors::InvalidArgument("Invalid data format: ", data_format); + } + + return attrs; +} + +xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, + xla::XlaOp conv_input, + xla::XlaOp filter, + const 
ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = conv_input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input)); + // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth] + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + + // For 2D convolution, there should be 4 dimensions. + int num_dims = attrs.num_spatial_dims + 2; + if (input_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString()); + } + if (filter_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument( + "filter must be ", num_dims, + "-dimensional: ", filter_shape.DebugString()); + } + + // The last two dimensions of the filter are the input and output shapes. + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims); + // The 'C' dimension for input is in_depth. It must be the same as + // the filter's in_depth. 
+ if (in_depth != input_shape.dimensions(feature_dim)) { + return errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, " vs ", + input_shape.dimensions(feature_dim)); + } + + if (attrs.depthwise) { + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); + } + + xla::ConvolutionDimensionNumbers dims; + std::vector window_strides(attrs.num_spatial_dims); + std::vector lhs_dilation(attrs.num_spatial_dims, 1); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector> padding(attrs.num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims); + dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = attrs.strides.at(dim); + rhs_dilation[i] = attrs.dilations.at(dim); + + int64 unused_output_size; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + input_shape.dimensions(dim), filter_shape.dimensions(i), + rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + return xla::ConvGeneralDilated( + conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, + dims, /*feature_group_count=*/attrs.depthwise ? 
in_depth : 1); +} + +xla::StatusOr MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + int num_dims = attrs.num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + auto* builder = filter.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(out_backprop)); + + xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, + out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, + attrs.data_format, &dims)); + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); + + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the gradient. 
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1); + dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims); + + std::vector kernel_spatial_dims(attrs.num_spatial_dims); + std::vector> padding(attrs.num_spatial_dims); + std::vector lhs_dilation(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); + + kernel_spatial_dims[i] = i; + padding[i] = {dims.spatial_dims[i].pad_before, + dims.spatial_dims[i].pad_after}; + lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = attrs.dilations[dim]; + } + + // Mirror the filter in the spatial dimensions. + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + return xla::ConvGeneralDilated( + out_backprop, mirrored_weights, /*window_strides=*/ones, padding, + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) / + filter_shape.dimensions(attrs.num_spatial_dims + 1) + : 1); +} + +xla::StatusOr MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = activations.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape activations_shape, + builder->GetShape(activations)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(gradients)); + const xla::Shape expanded_filter_shape = + attrs.depthwise ? 
ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // The last two dimensions of the filter are the input and output shapes. + int num_dims = attrs.num_spatial_dims + 2; + int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + std::vector> padding(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector window_strides(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. 
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64 pad_before = + attrs.padding == Padding::SAME ? 
std::max(pad_total / 2, 0) : 0; + + padding[i] = {pad_before, pad_total - pad_before}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; + } + + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + auto filter_backprop = + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); + + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } + + return filter_backprop; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..6e1b70a47850ae5c05939f8dfb7ec129c031df21 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +// This header exposes utilities for translating TensorFlow convolution ops into +// XLA ops. +// +// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g. +// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in +// this header to implement a new and exciting convolution op, for example a +// fused TensorFlow op that contains a convolution and other things. + +namespace tensorflow { + +// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA +// convolution. +struct ConvOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static xla::StatusOr Create(int num_spatial_dims, bool depthwise, + OpKernelConstruction* ctx); + + bool depthwise; + int num_spatial_dims; + std::vector dilations; + std::vector strides; + Padding padding; + TensorFormat data_format; +}; + +// Creates a new XLA forward or backward convolution with the given inputs and +// attributes. 
+xla::StatusOr MakeXlaForwardConvOp(StringPiece type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 674720e22fbf9d995e74c7dbd0ef7d7765941867..cd7c820be0b6029514ff74288e7bdd3f75b5d6b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,12 +15,17 @@ limitations under the License. // XLA-specific Ops for 2D convolution. +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,250 +38,28 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { - namespace { -// Returns the expanded size of a filter used for depthwise convolution. 
-// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. -TensorShape ExpandedFilterShapeForDepthwiseConvolution( - const TensorShape& shape) { - int num_dims = shape.dims(); - CHECK_GE(num_dims, 2); - TensorShape expanded_shape = shape; - expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) * - shape.dim_size(num_dims - 1)); - return expanded_shape; -} - -// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return xla::Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); -} - -// Create a mask for depthwise convolution that will make a normal convolution -// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tensor -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 -// -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 -// -// and divide B it by 2 to get -// 0 0 1 1 2 2 -// -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. 
-xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. - expanded_feature_iota = - xla::Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); - - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); -} - -// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to -// build a depthwise convolution. 
-xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape, - const xla::XlaOp& filter) { - int64 input_feature_dim = filter_shape.dims() - 2; - int64 output_feature_dim = filter_shape.dims() - 1; - int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim); - int64 input_feature = filter_shape.dim_size(input_feature_dim); - - // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = filter_shape; - implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1); - implicit_broadcast_filter_shape.set_dim(output_feature_dim, - depthwise_multiplier * input_feature); - return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); -} - -// Reduces the results of the convolution with an expanded filter to the -// non-expanded filter. -xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, - const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter_backprop, - xla::XlaBuilder* builder) { - auto masked_expanded_filter = xla::Select( - CreateExpandedFilterMask(filter_shape, builder), filter_backprop, - CreateExpandedZero(filter_shape, dtype, builder)); - return xla::Reshape( - // This reduce does not need inputs to be converted with - // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with - // ExpandedZero guarantees that only one element is non zero, so there - // cannot be accumulated precision error. 
- xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}), - filter_shape.dim_sizes()); -} - class ConvOp : public XlaOpKernel { public: explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), 
data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape input_shape = ctx->InputShape(0); - // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, ..., in_depth, out_depth] - const TensorShape filter_shape = ctx->InputShape(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES( - ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("input must be ", num_dims(), "-dimensional", - input_shape.DebugString())); - OP_REQUIRES( - ctx, filter_shape.dims() == num_dims(), - errors::InvalidArgument("filter must be ", num_dims(), - "-dimensional: ", filter_shape.DebugString())); - - // The last two dimension of the filter are the input and output shapes. - const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); - - // The 'C' dimension for input is in_depth. It must be the same as - // the filter's in_depth. 
- OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", input_shape.dim_size(feature_dim))); - - xla::XlaOp filter = ctx->Input(1); - if (depthwise_) { - filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); - } - - xla::ConvolutionDimensionNumbers dims; - std::vector window_strides(num_spatial_dims_); - std::vector lhs_dilation(num_spatial_dims_, 1); - std::vector rhs_dilation(num_spatial_dims_); - std::vector> padding(num_spatial_dims_); - - dims.set_input_batch_dimension(batch_dim); - dims.set_output_batch_dimension(batch_dim); - dims.set_input_feature_dimension(feature_dim); - dims.set_output_feature_dimension(feature_dim); - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_input_spatial_dimensions(dim); - dims.add_kernel_spatial_dimensions(i); - dims.add_output_spatial_dimensions(dim); - window_strides[i] = strides_.at(dim); - rhs_dilation[i] = dilations_.at(dim); - - int64 unused_output_size; - OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), filter_shape.dim_size(i), - rhs_dilation[i], window_strides[i], padding_, - &unused_output_size, &padding[i].first, &padding[i].second)); - } - - xla::XlaOp conv = xla::ConvGeneralDilated( - ctx->Input(0), filter, window_strides, padding, lhs_dilation, - rhs_dilation, dims, - /*feature_group_count=*/depthwise_ ? 
in_depth : 1); - ctx->SetOutput(0, conv); + xla::StatusOr conv = MakeXlaForwardConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); + OP_REQUIRES_OK(ctx, conv.status()); + ctx->SetOutput(0, conv.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); @@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel { public: explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - 
errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - TensorShape input_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - - const TensorShape filter_shape = ctx->InputShape(1); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - auto filter = ctx->Input(1); - auto out_backprop = ctx->Input(2); - - // The input gradients are computed by a convolution of the output - // gradients and the filter, with some appropriate padding. See the - // comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(batch_dim); - dnums.set_output_batch_dimension(batch_dim); - dnums.set_input_feature_dimension(feature_dim); - dnums.set_output_feature_dimension(feature_dim); - - // TF filter shape is [ H, W, ..., inC, outC ] - // Transpose the input and output features for computing the gradient. 
- dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1); - dnums.set_kernel_output_feature_dimension(num_spatial_dims_); - - std::vector kernel_spatial_dims(num_spatial_dims_); - std::vector> padding(num_spatial_dims_); - std::vector lhs_dilation(num_spatial_dims_); - std::vector rhs_dilation(num_spatial_dims_); - std::vector ones(num_spatial_dims_, 1); - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(i); - dnums.add_output_spatial_dimensions(dim); - - kernel_spatial_dims[i] = i; - padding[i] = {dims.spatial_dims[i].pad_before, - dims.spatial_dims[i].pad_after}; - lhs_dilation[i] = dims.spatial_dims[i].stride; - rhs_dilation[i] = dilations_[dim]; - } - - // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); - - // activation gradients - // = gradients (with padding and dilation) mirrored_weights - xla::XlaOp in_backprop = xla::ConvGeneralDilated( - out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums, - /*feature_group_count=*/ - depthwise_ ? 
out_backprop_shape.dim_size(feature_dim) / - filter_shape.dim_size(num_spatial_dims_ + 1) - : 1); - - ctx->SetOutput(0, in_backprop); + TensorShape input_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); + xla::Shape input_shape = + TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); + + xla::StatusOr in_backprop = + MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, + ctx->Input(1), ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, in_backprop.status()); + ctx->SetOutput(0, in_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); @@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel { public: explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - - OP_REQUIRES( - ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1), - 
errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape activations_shape = ctx->InputShape(0); - TensorShape filter_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp activations = ctx->Input(0); - xla::XlaOp gradients = ctx->Input(2); - - // The filter gradients are computed by a convolution of the input - // activations and the output gradients, with some appropriate padding. - // See the comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - - // The activations (inputs) form the LHS of the convolution. 
- // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); - - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); - - std::vector> padding(num_spatial_dims_); - std::vector rhs_dilation(num_spatial_dims_); - std::vector window_strides(num_spatial_dims_); - std::vector ones(num_spatial_dims_, 1); - - // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_output_spatial_dimensions(i); - } - dnums.set_output_batch_dimension(num_spatial_dims_); - dnums.set_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(dim); - - // We will also need to pad the input with zeros such that after the - // convolution, we get the right size for the filter. - // The padded_in_rows should be such that when we convolve this with the - // expanded_out_rows as a filter, we should get filter_rows back. - // - const int64 padded_in_size = - dims.spatial_dims[i].expanded_output_size + - (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; - - // However it can be smaller than input_rows: in this - // case it means some of the inputs are not used. 
- // - // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: - // - // INPUT = [ A B C ] - // - // FILTER = [ x y ] - // - // and the output will only have one column: a = A * x + B * y - // - // and input "C" is not used at all. - // - // We apply negative padding in this case. - const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; - - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int64 pad_before = - padding_ == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = dilations_[dim]; - } - - // Besides padding the input, we will also expand output_rows to - // expanded_out_rows = (output_rows - 1) * stride + 1 - // with zeros in between: - // - // a . . . b . . . c . . . d . . . e - // - // This is done by specifying the window dilation factors in the - // convolution HLO below. 
- auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (depthwise_) { - filter_backprop = ContractFilterForDepthwiseBackprop( - ctx, filter_shape, ctx->input_type(0), filter_backprop, b); - } - ctx->SetOutput(0, filter_backprop); + TensorShape filter_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape)); + xla::Shape filter_shape = + TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); + + xla::StatusOr filter_backprop = MakeXlaBackpropFilterConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, + ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, filter_backprop.status()); + ctx->SetOutput(0, filter_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index ef1015552d181a183d412f9c269dd5ec608b388f..234f7b4a019c9aac4bac4f906ddbae166ecd9a80 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // compute valid broadcast shapes, but rely below on XLA to // automatically perform the broadcast assuming its valid shapes are // a superset of TensorFlow's valid shapes. 
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); + BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), + /*fewer_dims_optimization=*/false); if (!bcast.IsValid()) { ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", lhs_shape.DebugString(), " vs. ", @@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } /* static */ std::pair XlaBinaryOp::Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper) { - // Manually construct the broadcasting since MapN does not do - // automatic broadcasting. The bcast helper ensures that - // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and - // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have - // the same shape, so can be operated on by MapN. - - // First reshape the inputs, which should be a metadata-only - // operation since we are flattening the dimensions in order. - auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); - - // Next broadcast the necessary input dimensions. We rely on the - // XLA optimizer to be smart about the fact that we are asking - // it to broadcast size 1 on some of these dimensions, to avoid - // adding complexity to this code. - auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); - int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); - int rhs_size = broadcast_helper.y_bcast().size(); - - // Now reshape them to the correct output shape. After the - // broadcast each side is twice as wide as it should be, since the - // broadcast dimensions were prepended to the shape. Reshape - // flattening each original dimension with the prepended broadcast - // dimension. E.g. 
if we started out with lhs_shaped with shape - // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have - // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. - std::vector lhs_reorder; - for (int i = 0; i < lhs_size; ++i) { - lhs_reorder.push_back(i); - lhs_reorder.push_back(i + lhs_size); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) { + auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape()); + if (!lhs_output.ok()) { + xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); + return {error, error}; } - auto lhs_output = - xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); - std::vector rhs_reorder; - for (int i = 0; i < rhs_size; ++i) { - rhs_reorder.push_back(i); - rhs_reorder.push_back(i + rhs_size); + auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape()); + if (!rhs_output.ok()) { + xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); + return {error, error}; } - auto rhs_output = - xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); - - return {lhs_output, rhs_output}; + return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()}; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index a5b870f8dbf70bcee331992345d63fd5d986bdca..516ead4bfe89b4ddeee11dcc6410a838d04f28a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel { // in the XLA documentation. 
virtual xla::XlaOp Computation( XlaOpKernelContext* ctx, const xla::XlaOp& lhs, - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, - const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const absl::Span& lhs_shape, const xla::XlaOp& rhs, + const absl::Span& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; void Compile(XlaOpKernelContext* ctx) override; @@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel { // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. static std::pair Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 12b0e38288e8f222ed506a75ec2575f27141c859..e96a1adce43c750314715107b4a1954d4a5b4e40 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -48,7 +48,7 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ed44ad218b6dc073583ec339da082b6881ad672d..49c12fc232092873b69961644a059abc6035f64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -29,7 +29,7 @@ namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 
xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, - gtl::ArraySlice other_dims, + absl::Span other_dims, xla::PrimitiveType element_type) { xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: @@ -177,8 +177,8 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); - tensorflow::gtl::ArraySlice other_dims(dims); - other_dims.pop_back(); + absl::Span other_dims(dims); + other_dims.remove_suffix(1); xla::XlaOp input = ctx->Input(0); xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index a3389d5b905bf3ee15744ab4fcee193d312e2ae0..4af1e8b44cbbd02d8e3ea5e42d841c92288b5d56 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - VLOG(3) << "DynamicUpdateSliceOp::Compile"; + DataType index_type = ctx->InputType("indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); - DataType index_type = input_type(2); - OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape update_shape = ctx->InputShape(1); - const TensorShape index_shape = ctx->InputShape(2); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape update_shape = ctx->InputShape("update"); + const TensorShape index_shape = ctx->InputShape("indices"); OP_REQUIRES( ctx, @@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp 
result = - xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = xla::DynamicUpdateSlice( + ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); ctx->SetOutput(0, result); } }; REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); +class DynamicSliceOp : public XlaOpKernel { + public: + explicit DynamicSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType index_type = ctx->InputType("start_indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); + CHECK(index_type == ctx->InputType("size_indices")); + + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape start_indices_shape = ctx->InputShape("start_indices"); + const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(start_indices_shape) && + start_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "start_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and start_indices has shape ", + start_indices_shape.DebugString())); + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(size_indices_shape) && + size_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "size_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and size_indices has shape ", + size_indices_shape.DebugString())); + + std::vector size_indices; + OP_REQUIRES_OK( + ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + xla::XlaOp result = xla::DynamicSlice( + ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"), + DynamicSliceOp); + } // namespace } // namespace tensorflow diff --git 
a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 6e1dbf5472f0b1eb0abcbe29c553ae926ecf2d8a..56da50f140893c68c8a1556853884720b21c7229 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "if" op. 
+ std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } ctx->SetOutput(i, output_handle); } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the conditional // bodies. 
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index f9bc98a198a72dcc0594e61971713bf890ce30b6..7783e13a8a5dacc1901392703687230020f82483 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel { DataType cond_type_; DataTypeVector input_types_; DataTypeVector output_types_; + bool has_token_input_output_; + std::vector token_input_nodes_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 33a73fe5fdf403e513be085dd7bcea3255277b4a..6713d6bc921b24b25baddfb3fd7296fffcc3d6ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -316,6 +318,70 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +struct WhileCondFn { + const int64 num_boxes; + const int64 output_size; + + explicit WhileCondFn(int64 num_boxes, int64 output_size) + : num_boxes(num_boxes), output_size(output_size) {} + + xla::StatusOr operator()(absl::Span values, + xla::XlaBuilder* cond_builder) const { + xla::XlaOp row_idx = values[0]; + xla::XlaOp row_in_bounds = + xla::Lt(row_idx, xla::ConstantR0(cond_builder, num_boxes)); + xla::XlaOp num_outputs_so_far = values[1]; + xla::XlaOp results_not_full = xla::Lt( + num_outputs_so_far, 
xla::ConstantR0(cond_builder, output_size)); + return xla::And(row_in_bounds, results_not_full); + } +}; + +// Process the boxes one-by-one using the iou matrix mask. +// This implementation uses a correct, but greedy, sequential algorithm +// to ensure that suppressed boxes cannot themselves suppress other +// boxes. +struct SuppressBodyFn { + const int64 num_boxes; + + explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {} + + xla::StatusOr> operator()( + absl::Span values, xla::XlaBuilder* builder) const { + auto row_idx = values[0]; + auto num_outputs_so_far = values[1]; + auto iou_mask = values[2]; + auto included_iou = values[3]; + auto zero_r1 = xla::ConstantR1(builder, {0}); + // Determine if current elem is active using a slice. + auto row_idx_r1 = xla::Reshape(row_idx, {1}); + auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + active_elem = xla::Reshape(active_elem, {}); + // Increment output count iff current elem is not suppressed. + num_outputs_so_far = xla::Select( + active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), + num_outputs_so_far); + // Slice out the row_idx. + auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); + auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + // Remove the diagonal from consideration. An elem cannot suppress + // itself. + auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); + row_iou = xla::DynamicUpdateSlice( + row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), + update_starts); + // Create a suppression by inverting polarity. + row_iou = xla::Reshape(row_iou, {num_boxes}); + auto supp_mask = xla::Not(row_iou); + // Update mask iff current elem is not suppressed. 
+ included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), + xla::And(included_iou, supp_mask), included_iou); + row_idx = row_idx + xla::ConstantR0(builder, 1); + return std::vector{row_idx, num_outputs_so_far, iou_mask, + included_iou}; + } +}; + class NonMaxSuppressionOp : public XlaOpKernel { public: explicit NonMaxSuppressionOp(OpKernelConstruction* context) @@ -326,14 +392,12 @@ class NonMaxSuppressionOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override { // TODO(b/111646731): Improve scalability of this op, using blocking. - int num_boxes_dim = 0; - int coords_dim = 1; const TensorShape& boxes_shape = context->InputShape("boxes"); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), errors::InvalidArgument("boxes must be 2-D, currently: ", boxes_shape.DebugString())); - const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); - OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + const int64 num_boxes = boxes_shape.dim_size(0); + OP_REQUIRES(context, boxes_shape.dim_size(1) == 4, errors::InvalidArgument("boxes must have 4 columns", boxes_shape.DebugString())); const TensorShape& scores_shape = context->InputShape("scores"); @@ -347,107 +411,148 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES(context, pad_to_max_output_size_, errors::InvalidArgument( "XLA compilation requires pad_to_max_output_size == True")); + OP_REQUIRES(context, num_boxes <= kint32max, + errors::InvalidArgument("XLA compilation requires number of " + "boxes to be <= kint32max, got ", + num_boxes)); - xla::XlaOp boxes = context->Input("boxes"); - xla::XlaOp scores = context->Input("scores"); + const xla::XlaOp boxes_input = context->Input("boxes"); + const xla::XlaOp scores_input = context->Input("scores"); int64 output_size; OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); - 
xla::XlaOp score_thresh = context->Input("score_threshold"); - xla::XlaOp iou_thresh = context->Input("iou_threshold"); - + OP_REQUIRES(context, output_size <= kint32max, + errors::InvalidArgument("Need output_size <= kint32Max, got ", + output_size)); + const xla::XlaOp score_thresh = context->Input("score_threshold"); + const xla::XlaOp iou_thresh = context->Input("iou_threshold"); xla::XlaBuilder* const builder = context->builder(); // Choose a more convenient layout. - xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); - coords_dim = 0; - num_boxes_dim = 1; - - // Shapes are henceforth [1, num_boxes]. - xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, - /*start_index=*/0, - /*limit_index=*/1, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, - /*start_index=*/1, - /*limit_index=*/2, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, - /*start_index=*/2, - /*limit_index=*/3, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, - /*start_index=*/3, - /*limit_index=*/4, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp y1 = - xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); - xla::XlaOp y2 = - xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); - xla::XlaOp x1 = - xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); - xla::XlaOp x2 = - xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0}); + const xla::XlaOp boxes_sorted = xla::GetTupleElement( + xla::Sort(/*keys=*/-xla::Broadcast(scores_input, {4}), + /*values=*/{boxes}, + /*dimension=*/1), + 1); + // Track the mapping of indices into sorted domain. 
+ const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes); + const xla::XlaOp indices_sort = xla::Sort(-scores_input, {iota_indices}); + const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); + const xla::XlaOp scores = xla::Neg(xla::GetTupleElement(indices_sort, 0)); + + // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. + const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + + xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1); + xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0); + xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1); + xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); xla::XlaOp area = (y2 - y1) * (x2 - x1); - // Transpose the 1xN tensors, instead of the NxN tensors. - xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); - xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); - xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); - xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); - xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + // Shapes are henceforth [1, num_boxes]. + y1 = xla::Broadcast(y1, {1}); + y2 = xla::Broadcast(y2, {1}); + x1 = xla::Broadcast(x1, {1}); + x2 = xla::Broadcast(x2, {1}); + area = xla::Broadcast(area, {1}); // Shapes are henceforth [num_boxes, num_boxes]. 
- xla::XlaOp i_xmin = xla::Max(x1, x1_t); - xla::XlaOp i_ymin = xla::Max(y1, y1_t); - xla::XlaOp i_xmax = xla::Min(x2, x2_t); - xla::XlaOp i_ymax = xla::Min(y2, y2_t); + xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); + xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); + xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); + xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0})); auto square_zero = xla::ZerosLike(i_xmin); xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * xla::Max(i_ymax - i_ymin, square_zero); - xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area; xla::XlaOp iou = i_area / u_area; xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); - xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); - xla::XlaOp score_cmp_mask = - xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); - xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); - - // Shapes are [num_boxes] after the reduce. 
- xla::XlaOp included_iou = xla::Not(xla::Reduce( - suppress, - /*init_value=*/xla::ConstantR0(builder, false), - /*computation=*/CreateScalarOrComputation(xla::PRED, builder), - /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_iou = + xla::Broadcast(xla::ConstantR0(builder, true), {num_boxes}); + + std::vector init_values; + init_values.reserve(4); + init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx + init_values.push_back(xla::ConstantR0(builder, 0)); // num_outputs + init_values.push_back(iou_thresh_mask); + init_values.push_back(included_iou); + + auto suppress_loop_result = + XlaWhileLoop(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, "suppress_loop", + builder) + .ValueOrDie(); + xla::XlaOp included_score = xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); - xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]); + + // Only consider boxes over which we have iterated. This allows for accurate + // counting. DynamicSlice would require knowledge of the size of the output. + auto valid_elem = xla::Lt( + iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes})); + included = xla::And(included, valid_elem); + xla::XlaOp neg_inf = xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); - + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1); + // Calculate num_valid. + // Note: num_valid cannot be taken from the loop outputs, because outputs + // can be suppressed by score threshold. xla::XlaOp ones_included = xla::Select( included, xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); - - // num_valid is scalar. - xla::XlaOp num_valid = xla::Reduce( + // num_valid is scalar. 
Value should be bound by output_size. + xla::XlaOp num_valid_total = xla::Reduce( ones_included, /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); - - xla::XlaOp output_tuple = TopK(scores_included, output_size); - xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + xla::XlaOp num_valid = + xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); + + // Re-index into the original scores input tensor, using a Gather. + // Boxes were suppressed in the sorted domain. + xla::XlaOp selected_indices; + DataType gather_type = context->expected_output_dtype(0); + OP_REQUIRES_OK( + context, + XlaGather(indices_sorted, scores_shape, selected_indices_sorted, + TensorShape({output_size}), + /*axis=*/0, + /*indices_are_nd=*/false, + /*dtype=*/gather_type, DT_INT32, builder, &selected_indices)); context->SetOutput(0, selected_indices); context->SetOutput(1, num_valid); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8e071bf0b7ae638888818ea8cd5d63b5d543342e..7b2bb4a7c50fc954237e09a32f71009f790b60d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -78,7 +79,7 @@ struct ResizeConvolutionDims { std::vector stride; }; ResizeConvolutionDims ComputeResizeConvolutionParameters( - gtl::ArraySlice in_size, gtl::ArraySlice out_size, + absl::Span in_size, absl::Span out_size, bool align_corners) { CHECK_EQ(in_size.size(), out_size.size()); int num_spatial_dims = in_size.size(); @@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -std::vector Make1DKernel(int64 n) { +xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return kernel; + return xla::ConstantR1(builder, kernel); } // Kernels with more than 16 spatial elements are considered intense and the @@ -147,43 +148,28 @@ std::vector Make1DKernel(int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + auto depthwise_kernel = xla::Broadcast( + xla::Zero(builder, xla::F32), + {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); - auto diag = xla::ConvertElementType( - xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, - 2 * kernel_size[1] - 1, channels}), - channels_iota, 
/*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); return xla::Mul( - xla::Mul(diag, - xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), /*broadcast_dimensions=*/{1}), - xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), + Make1DKernel(builder, kernel_size[0]), /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); - - auto diag = xla::ConvertElementType( - xla::Eq( - xla::Broadcast(channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); - if (dim == 1) { - return xla::Mul( - diag, xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}); - } - return xla::Mul(diag, - xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + auto depthwise_kernel = + xla::Broadcast(xla::Zero(builder, xla::F32), + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? 
(2 * kernel_size[1] - 1) : 1, channels, 1}); + return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), + /*broadcast_dimensions=*/{dim}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { dimension_numbers.add_input_spatial_dimensions(1 + i); dimension_numbers.add_output_spatial_dimensions(1 + i); @@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, upper_padding[0]}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*padding=*/ {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); output = xla::ConvGeneralDilated( @@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, - 
/*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // Add broadcasts to handle expanding from a size == 1 dimension to a @@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(1 + i); - dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_input_spatial_dimensions(i + 1); + dimension_numbers.add_output_spatial_dimensions(i + 1); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = @@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* 
builder, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/{1, dims.stride[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 22a45b2a11e8ecb688f8e773ef4b286eafe68f4f..f210bfbd886e48b8d7972393ed1899491486646c 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -78,30 +78,40 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } - xla::Shape xla_shape = - xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); + // The argmax function expects row-major layout. 
+ xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla::S64, output_shape.dim_sizes()); + std::vector arg_shapes; + for (const xla::XlaOp& arg : args) { + auto shape_status = b.GetShape(arg); + OP_REQUIRES_OK(ctx, shape_status.status()); + xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); + *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( + xla::ShapeUtil::Rank(arg_shape)); + arg_shapes.push_back(std::move(arg_shape)); + } // Tell XLA to call the custom code, defined in // index_ops_kernel_argmax_float_1d.cc. xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = - xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); + output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, + xla_shape, arg_shapes); break; case 2: - output = - xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); + output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, + xla_shape, arg_shapes); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2..2a42eeaf76ab3aa88ff3a93ef7eb7ab217964bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -29,7 +29,14 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, const xla::LiteralSlice& pad_literal, + const MirrorPadMode mode, xla::XlaBuilder* b) { + // The difference in the semantics of REFLECT and SYMMETRIC is that REFLECT + // will not mirror the border values while symmetric does. + // e.g. input is [1, 2, 3] and paddings is [0, 2], then the output is: + // - [1, 2, 3, 2, 1] in reflect mode + // - [1, 2, 3, 3, 2] in symmetric mode. + int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 
1 : 0; xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { @@ -39,9 +46,19 @@ class MirrorPadOp : public XlaOpKernel { TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + + // Padding amounts on each side must be no more than the size of the + // original shape. + TF_RET_CHECK(lhs_padding >= 0 && + lhs_padding <= dim_size - excluded_edges); + TF_RET_CHECK(rhs_padding >= 0 && + rhs_padding <= dim_size - excluded_edges); + + auto lhs_pad = + xla::SliceInDim(t_rev, dim_size - excluded_edges - lhs_padding, + dim_size - excluded_edges, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, excluded_edges, + excluded_edges + rhs_padding, 1, dimno); accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; @@ -53,9 +70,10 @@ class MirrorPadOp : public XlaOpKernel { MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); - OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT, - xla::Unimplemented( - "Only REFLECT MirrorPad mode is currently supported")); + OP_REQUIRES( + ctx, mode == MirrorPadMode::REFLECT || mode == MirrorPadMode::SYMMETRIC, + xla::Unimplemented("Unsupported MirrorPad mode. 
Only SYMMETRIC and " + "REFLECT modes are currently supported")); const int dims = input_shape.dims(); OP_REQUIRES( @@ -83,7 +101,7 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = - DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, mode, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ca5eecf1a811aca9ad9201ba285d2112db7533e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +class DataFormatVecPermuteOp : public XlaOpKernel { + public: + explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_)); + OP_REQUIRES( + ctx, src_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + TensorFormat data_format; + OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_)); + OP_REQUIRES( + ctx, dst_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + const TensorShape input_tensor_shape = ctx->InputShape(0); + int input_rank = input_tensor_shape.dims(); + OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2, + errors::InvalidArgument( + "Input must be a vector or matrix, but got shape ", + input_tensor_shape.DebugString())); + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(0) == 4, + errors::InvalidArgument( + "First dimension of input must be of size 4, but got shape ", + input_tensor_shape.DebugString())); + if (input_rank == 2) { + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(1) == 2, + errors::InvalidArgument( + "Second dimension of 2D input must be of size 2, but got shape ", + 
input_tensor_shape.DebugString())); + } + std::vector dst_indices(4, 0); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + if (src_format_[i] == dst_format_[j]) { + dst_indices[i] = j; + break; + } + } + } + auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); + if (input_rank == 2) { + keys = xla::BroadcastInDim( + keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + } + auto sorted = xla::Sort(keys, {ctx->Input(0)}, 0); + auto output = xla::GetTupleElement(sorted, 1); + ctx->SetOutput(0, output); + } + + private: + string src_format_; + string dst_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp); +}; + +// TODO(b/115384656): Support DT_INT64. +REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32), + DataFormatVecPermuteOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index f6f158a73be42ea2602811ad64a2a2c655dab088..27690c156e4da129ad139f3880bba3a208b5606d 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -138,7 +138,7 @@ xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, int num_dims = num_spatial_dims + 2; int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); - gtl::InlinedVector spatial_dimensions(num_spatial_dims); + absl::InlinedVector spatial_dimensions(num_spatial_dims); for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { spatial_dimensions[spatial_dim] = GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index de9068a640dc03b141b6954eaa1629dd6c8c1f3a..7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -23,15 +23,10 @@ namespace { class QROp : public XlaOpKernel { public: explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - bool full_matrices; - OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices)); - OP_REQUIRES( - ctx, full_matrices, - errors::Unimplemented("full_matrices=False case of QR decomposition is " - "not implemented in TF/XLA")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0)); + auto result = QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; @@ -39,6 +34,11 @@ class QROp : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie().q); ctx->SetOutput(1, result.ValueOrDie().r); } + + private: + // If true, compute full-sized q and r. If false, compute only the leading P + // columns of q. + bool full_matrices_; }; REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2da9340625db08b14b78340c471f096baf15689d..7ef6fa305b7f5b5aae187808f856a9273f101e14 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -135,7 +135,7 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp curr = input; for (int i = 0; i < rounds; ++i) { xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); - xla::XlaOp sorted = xla::Sort(keys, curr); + xla::XlaOp sorted = xla::Sort(keys, {curr}); curr = xla::GetTupleElement(sorted, 1); } @@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp indices = xla::Iota(builder, xla::S32, n); // Swap the indices at i and swaps[i]. 
- auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto swap_body_fn = [&](xla::XlaOp i, + absl::Span loop_vars, xla::XlaBuilder* builder) -> xla::StatusOr> { auto swaps = loop_vars[0]; diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8102faad28db71075fb8da269c55edbdb667193e..8eee5b12991fb377203d780cecd8916952bd699a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel { std::vector window_dimensions; std::vector window_strides; + std::vector base_dilations; + std::vector window_dilations; OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( "window_dimensions", &window_dimensions)); OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations", + &base_dilations)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dilations", &window_dilations)); const int rank = input_shape.dims(); OP_REQUIRES(context, rank == window_dimensions.size(), @@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel { "The size of window_strides must be equal to the input " "rank (", window_strides.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == base_dilations.size(), + errors::InvalidArgument( + "The size of base_dilations must be equal to the input " + "rank (", + base_dilations.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_dilations.size(), + errors::InvalidArgument( + "The size of window_dilations must be equal to the input " + "rank (", + window_dilations.size(), " vs. ", rank, ")")); // Build the reducer function. 
XlaCompiler::Argument reducer_arg; @@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel { xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), *reducer.computation, - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); context->SetOutput(0, output); } @@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaReduceWindow") .CompileTimeConstInput("window_dimensions") .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("base_dilations") + .CompileTimeConstInput("window_dilations") .CompileTimeConstInput("padding"), ReduceWindowOp); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 598248563bb93146e6dea3016822d26b8bf368e7..118f2798d559f43acb7f6394a7337426164325ef 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -69,7 +69,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << absl::StrJoin(axes, ","); - gtl::InlinedVector bitmap(data_shape.dims(), false); + absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { @@ -103,7 +103,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. 
- xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(absl::StrCat(desc, "-reduction")); xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 64900e4709fd3e16d21096b0cfff8922906cb0d4..e172c649325adb6f7761ce0be141f21e8d545bc1 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel { } else { xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + DataType input_type = ctx->input_type(0); + XlaContext& tc = XlaContext::Get(ctx); + + if (input_type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + ctx->SetStatus(tc.AddResourceRetval(index_, resource)); + return; + } auto is_constant = ctx->builder()->IsConstant(input); if (!is_constant.ok()) { @@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel { return; } - XlaContext& tc = XlaContext::Get(ctx); if (tc.resolve_compile_time_constants() && (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; @@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index c0afccaa5b15dd33fcd016dfdd9bb18e244bf90a..8494864b33a44b03a07e3fea7766285f54074e7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -97,7 +97,7 @@ class ReverseV2Op : public XlaOpKernel { // witnessed_axes is used to ensure that the same axis is 
not marked to be // reversed multiple times. - gtl::InlinedVector witnessed_axes(x_shape.dims(), false); + absl::InlinedVector witnessed_axes(x_shape.dims(), false); for (int d = 0; d < axes.size(); ++d) { OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ab094d7dd1ce9856a3c2854fd2776827d6c4b76f..57afd608de820573821d605cadcc8779474b5fd6 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel { } auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, padding); + *reducer, window_dims, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding); output = XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f..9e4c57c9bf73369662274f6b783418e18ff860c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -66,8 +66,8 @@ class SelectOp : public XlaOpKernel { // XLA. It seems we have to broadcast on the left and then Reshape // to get the dimensions in the right order. 
const auto dim_sizes = then_shape.dim_sizes(); - gtl::ArraySlice bdims = dim_sizes; - bdims.pop_front(); + absl::Span bdims = dim_sizes; + bdims.remove_prefix(1); cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 25a5bcbe1dd27d741ce3b74125ba9ce425ee78f3..0c32b8def0f7b741c93e803f8359b6504087e257 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::LiteralSlice& start_literal, - const xla::LiteralSlice& limit_literal, - const xla::LiteralSlice& delta_literal, - Tensor* output) { +xla::StatusOr CreateRangeTensor( + const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal, ? 
((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) : std::ceil(std::abs((limit - start) / delta))); - *output = Tensor(DataTypeToEnum::v(), TensorShape({size})); - auto flat = output->flat(); - T val = start; - for (int64 i = 0; i < size; ++i) { - flat(i) = val; - val += delta; - } - return Status::OK(); + return xla::ConstantR0(builder, start) + + xla::ConstantR0(builder, delta) * + xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType(), + size); } class RangeOp : public XlaOpKernel { @@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); DataType type = input_type(0); - Tensor output; - Status status; + xla::StatusOr output; switch (type) { case DT_INT32: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_INT64: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_FLOAT: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_DOUBLE: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; default: - status = errors::InvalidArgument("Invalid type for Range ", + output = errors::InvalidArgument("Invalid type for Range ", DataTypeString(type)); } - OP_REQUIRES_OK(ctx, status); - ctx->SetConstantOutput(0, output); + OP_REQUIRES_OK(ctx, output.status()); + ctx->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b..c8a0f31a0375abacaca26688a23f4835e11c692e 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -44,7 +44,7 @@ class ShapeOp : public 
XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -115,7 +115,7 @@ class ExpandDimsOp : public XlaOpKernel { // accept legacy scalars, even when they should be forbidden by the graphdef // version. OP_REQUIRES(ctx, dim_shape.num_elements() == 1, - errors::InvalidArgument(strings::StrCat( + errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 6adc3c58de63ee70c26bed47eebef955893df4a5..537b71f3c0cf3622a8a45a717ac406da69f5c3c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Slice Op. +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01ccb303091a6d37d1aeb4b2a3377dc638..6cfdf4a5ae479e9851454df97160754f122bc6ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), {context->Input("values")}); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 7327258c31f21f45ff7ffffbc9db7a2a70b4a14c..76b79be6f6f6b5ecbe9edcffb81f2834fdac9a56 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -23,10 +23,10 @@ namespace { void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, 
const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); @@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4493539fe34f0ce635fdc58660d4ff90af9c9379..3293c13b21bc4825c83f494b7f2d48a9b3000f9e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -48,7 +48,7 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index df91900570107609c0f1c2281faaab8a5e65b98b..ee70f508a9586d5f47bd7bb7670506d4c92b369f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -111,7 +111,7 @@ class StackOp : public XlaOpKernel { xla::XlaOp value; XlaContext& xc = XlaContext::Get(ctx); XlaResource* resource; - string name = strings::StrCat("Stack: ", stack_name_); + string name = absl::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, TensorShape(), value, /*tensor_array_size=*/size, 
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 1062399d91bd9a9bf8c3820c5ecac534c110746d..2b2e3de64fd0db9d99efa46ecaf7a0fefbae6645 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/util/strided_slice_op.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { @@ -46,9 +46,9 @@ class StridedSliceOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -72,8 +72,8 @@ class StridedSliceOp : public XlaOpKernel { shrink_axis_mask_, &dummy_processing_shape, &final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_end, slice_strides; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_end, slice_strides; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { @@ -127,9 +127,9 @@ class StridedSliceGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { 
TensorShape processing_shape, final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -175,7 +175,7 @@ class StridedSliceGradOp : public XlaOpKernel { grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. - gtl::InlinedVector dimensions_to_reverse; + absl::InlinedVector dimensions_to_reverse; xla::PaddingConfig padding_config; for (int i = 0; i < processing_shape.dims(); ++i) { @@ -238,9 +238,9 @@ class StridedSliceAssignOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -287,8 +287,8 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_dims; for (int i = 0; i < begin.size(); ++i) { // TODO(phawkins): implement strides != 1 OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index be1814d8e3ae2c0ddad0134b9288e0ea084aa81b..06a560d9471c352065ef7e9f6903ebdca542f5b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -122,10 +122,11 @@ Status GetTensorArrayShape(const XlaResource* resource, // relevant slice of 'operand'. 
xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, - const gtl::ArraySlice& update_dims, - const xla::XlaOp& start_indices) { + absl::Span update_dims, + const xla::XlaOp& start_indices, DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = xla::Add(current, update); + xla::XlaOp sum = + dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update); return xla::DynamicUpdateSlice(operand, sum, start_indices); } @@ -167,7 +168,7 @@ class TensorArrayOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); XlaResource* var; - string name = strings::StrCat("TensorArray: ", tensor_array_name_); + string name = absl::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), dtype_, shape, value, /*tensor_array_size=*/size, @@ -222,9 +223,16 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - xla::XlaOp written = - DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - + xla::XlaOp written; + if (resource->tensor_array_multiple_writes_aggregate()) { + written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), + start_indices, dtype_); + } else { + // TODO(b/117569591): Ideally we would report an error in the case that we + // see multiple writes to the same offset. Unfortunately there is no way + // to report errors at the moment, so we silently overwrite. 
+ written = xla::DynamicUpdateSlice(ta, update, start_indices); + } OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -391,7 +399,11 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = xla::Add(ta, value); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, value); + } else { + ta = xla::Add(ta, value); + } } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -414,7 +426,7 @@ class TensorArrayScatterOp : public XlaOpKernel { auto start_indices = xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } @@ -522,8 +534,13 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( - ta, xla::Reshape(value, ta_shape.dim_sizes())))); + const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, reshape); + } else { + ta = xla::Add(ta, reshape); + } + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..74d4fcc425bdadb70a7bedf2487deaf6c4a4f7b9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorList operators. + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, + TensorShape* tensor_list_shape) { + auto shape_or_status = builder->GetShape(op); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + xla::Shape shape = shape_or_status.ValueOrDie(); + TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), + tensor_list_shape); +} + +class TensorListReserveOp : public XlaOpKernel { + public: + explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, 
ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + + TensorShape tensor_shape; + tensor_shape.AddDim(num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, 0)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp); +}; + +REGISTER_XLA_OP(Name("TensorListReserve") + .CompileTimeConstInput("element_shape") + .CompileTimeConstInput("num_elements"), + TensorListReserveOp); + +class EmptyTensorListOp : public XlaOpKernel { + public: + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + ctx->CtxFailure( + errors::InvalidArgument("XLA compilation requires a fixed tensor list " + "size. 
Use TensorListReserve instead.")); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); +}; + +REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); + +class TensorListElementShapeOp : public XlaOpKernel { + public: + explicit TensorListElementShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + shape.RemoveDim(0); + + switch (shape_type_) { + case DT_INT64: + ctx->SetOutput(0, xla::ConstantR1(b, shape.dim_sizes())); + break; + case DT_INT32: { + std::vector size; + for (int64 s : shape.dim_sizes()) { + size.push_back(s); + } + ctx->SetOutput(0, xla::ConstantR1(b, size)); + break; + } + default: + ctx->CtxFailure( + errors::InvalidArgument("Unsupported shape type requested")); + return; + } + } + + private: + DataType shape_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp); +}; + +REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); + +class TensorListPushBackOp : public XlaOpKernel { + public: + explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp list = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(1); + + xla::XlaOp ta = xla::GetTupleElement(list, 0); + xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp value = ctx->Input(1); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
+ auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + ctx->SetOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp); + +class TensorListPopBackOp : public XlaOpKernel { + public: + explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); + + index = index - xla::ConstantR0(b, 1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); + + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. 
+ std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetOutput(1, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 2c7213f322eb6fec1f134a444b569ae72307d00f..52f2b36e19edd96f491f6706d1872e0d3af2df3b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Tile Op. #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -96,7 +96,11 @@ class TileOp : public XlaOpKernel { // operation broadcast semantics. 
auto broadcasted_zero = xla::Broadcast( XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); + if (ctx->input_type(0) == DT_BOOL) { + ctx->SetOutput(0, xla::Or(broadcasted_zero, input)); + } else { + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); + } return; } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index be5e91138656716daddcc3c7a68dbb78ecb69103..7077c2e3a546e198bdb4ff944ea531f3158810f2 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -688,7 +688,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, } // grad_to_use = grad + 2 * l2_shrinkage * var - // new_accum = accum + grad_to_use * grad_to_use + // new_accum = accum + grad * grad // linear += grad_to_use - // (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 @@ -704,7 +704,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, grad_to_use = grad; } - xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum = accum + xla::Square(grad); xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index f9148b394212777271f9eba51313ee17b19819af..6b303b31d43ce2249a87f25723caf34f84c8387d 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -61,7 +61,7 @@ class TransposeOp : public XlaOpKernel { std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). 
- gtl::InlinedVector bits(dims); + absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { const int32 d = perm[i]; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 296518229ebf0ba46717afc4f26d5ae1551c2862..559414eeaa5fec75e5a9d1866baaf738c024cd15 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { cond_name_attr_ = *name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); body_name_attr_ = *name_attr; + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; body_options.is_entry_computation = false; + body_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; cond_options.is_entry_computation = false; + cond_options.add_token_input_output = has_token_input_output_; 
XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); @@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "while" op. + std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(builder, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); @@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, i)); } } + if (has_token_input_output_) { + // Set token output for this "while" op. + xla::XlaOp token_output = + xla::GetTupleElement(while_result, ctx->num_outputs()); + auto shape_or = builder->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the loop. 
for (int i = 0; i < body.resource_updates.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 67edebabf9f643a919d0f06c228e2d224a49a2af..aeeff40e68f8b778628b9e85bd9b4ddcb73883a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel { private: NameAttrList cond_name_attr_; NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector token_input_nodes_; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 8848623868091f8d19b1622f23ba23c68689d90d..fecc7c556eb4121b912796e5811632c46769b479 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -84,7 +84,7 @@ class XlaConvOp : public XlaOpKernel { private: xla::ConvolutionDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 2fed53e5c072e1a50e0f07f45357ee86c90f986f..40b15b5579ab9862b9d30df74af9877c98c4aa2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -54,7 +54,7 @@ class XlaDotOp : public XlaOpKernel { private: xla::DotDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); }; diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 99511e991422014c877fb5f6b7fb6a914e730f40..1ce3930fd1cd91f8e8dfb765b49be2dc969d1bd7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -31,6 +31,22 
@@ cc_library( ], ) +cc_library( + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "cholesky", srcs = ["cholesky.cc"], @@ -104,6 +120,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -166,6 +183,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -203,6 +221,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index d8c050d09e871c80e128989c9fbdb57c266b19ed..5400e8834cb9807f6dd71abe7789b2672e29e905 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -28,7 +28,7 @@ namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -96,20 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool 
transpose_x, y = xla::Conj(y); } - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose - // HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs, &precision_proto); - } - xla::DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 6cfccd55530ff40a309673d57d1fe61fc8264316..6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,11 +43,11 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::DEFAULT); +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e402ef855cd7c114332d84032bc869232404fc8 --- /dev/null +++ 
b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { + +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + absl::Span input_dims = + xla::AsInt64Slice(input_shape.dimensions()); + + if (input_dims == output_dims) { + return input; + } + + if (input_dims.size() > output_dims.size()) { + return errors::InvalidArgument( + "Input shape (", xla::ShapeUtil::HumanString(input_shape), + ") must have rank less than or equal to the output shape [", + absl::StrJoin(output_dims, ","), "]"); + } + + std::vector broadcast_dims; + std::vector broadcast_shape; + auto input_it = input_dims.rbegin(); + for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend(); + ++output_it) { + if (input_it != input_dims.rend()) { + if (!(*output_it == 0 && *input_it == 0) && + !(*input_it != 0 && *output_it % *input_it == 0)) { + return 
errors::InvalidArgument("Invalid shape broadcast from ", + xla::ShapeUtil::HumanString(input_shape), + " to [", absl::StrJoin(output_dims, ","), + "]"); + } + + broadcast_dims.push_back(broadcast_shape.size()); + if (*output_it == *input_it) { + broadcast_shape.push_back(*output_it); + } else if (*output_it != *input_it) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(*input_it); + broadcast_shape.push_back(*output_it / *input_it); + } + ++input_it; + } else { + broadcast_shape.push_back(*output_it); + } + } + TF_RET_CHECK(input_it == input_dims.rend()); + + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::BroadcastInDim( + input, + xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape), + broadcast_dims); + if (broadcast_shape != output_dims) { + output = xla::Reshape(output, output_dims); + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/status_util.h b/tensorflow/compiler/tf2xla/lib/broadcast.h similarity index 50% rename from tensorflow/core/util/status_util.h rename to tensorflow/compiler/tf2xla/lib/broadcast.h index ea92f61dce0b4e3a9470e25d96dbb599954ea46f..591e696f06b994a7fdea58bc95ba785f683ce7d1 100644 --- a/tensorflow/core/util/status_util.h +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -13,24 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ -#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/strings/strcat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { -// Creates a tag to be used in an exception error message. This can be parsed by -// the Python layer and replaced with information about the node. -// -// For example, error_format_tag(node, "${file}") returns -// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as -// e.g. "file/where/node/was/created.py". -inline string error_format_tag(const Node& node, const string& format) { - return strings::StrCat("^^node:", node.name(), ":", format, "^^"); -} +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. 
+xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims); } // namespace tensorflow -#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 67fb56510cbd0677a2b78e2090f98b602539c6bd..ab3d0a566839343828d176d9a46672824e425613 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -50,20 +50,21 @@ namespace { // l[..., j, j] // return l xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int n_dims = xla::ShapeUtil::Rank(a_shape); const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); + auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - 2); xla::XlaOp l = xla::ZerosLike(a); // Construct the for loop body to iterate over rows. 
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::Shape col_shape; @@ -149,7 +150,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 60cd7ded53fe862f29ca2bb68b175fcd1c89b70c..9a561c34b92ee45059f2a05336e682838f8e36e2 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,9 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp Cholesky( + xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index b6f30d8d49bf05813fa6fccc4544b0631f866490..6b3f2b6e065b5c99e2d0248237369ecc30188aa5 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -65,9 +65,9 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. 
-xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice batch_dims, - const int64 m, xla::XlaOp* v, xla::XlaOp* tau, - xla::XlaOp* beta) { +xla::Status House(xla::XlaOp x, xla::XlaOp k, + absl::Span batch_dims, const int64 m, + xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { xla::XlaBuilder* const builder = x.builder(); TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); const xla::PrimitiveType type = x_shape.element_type(); @@ -150,7 +150,7 @@ struct QRBlockResult { xla::XlaOp vs; // Shape: [..., m, n] }; xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { + xla::XlaOp a, xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -173,7 +173,7 @@ xla::StatusOr QRBlock( std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); auto qr_body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto a = values[0]; auto vs = values[1]; @@ -255,15 +255,15 @@ xla::StatusOr QRBlock( // There is no need to return Y since at termination of the loop it is equal to // vs. 
xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, + xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; auto body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto w = values[0]; auto y = values[1]; @@ -331,8 +331,8 @@ xla::StatusOr ComputeWYRepresentation( // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. xla::StatusOr QRDecomposition( - xla::XlaOp a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::XlaOp a, bool full_matrices, int64 block_size, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -396,6 +396,13 @@ xla::StatusOr QRDecomposition( q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } QRDecompositionResult result; + + // full_matrices is false when only a partial result in needed. Slice to the + // needed dimensions here. 
+ if (!full_matrices) { + q = SliceInMinorDims(q, {0, 0}, {m, p}); + a = SliceInMinorDims(a, {0, 0}, {p, n}); + } result.q = q; result.r = a; return result; diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 05565477b6062618a75f929b69c38938ddfd7a5a..24b537ac8b63b93e734c3d0e335ea455f7d51a54 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -34,9 +34,8 @@ struct QRDecompositionResult { }; xla::StatusOr QRDecomposition( - xla::XlaOp a, int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); + xla::XlaOp a, bool full_matrices, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index ba22eff73abab11abeb57283c63318b2e50a9ca1..2b1c2ced925d9fee7392986015a6e716a94d356f 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -38,12 +38,10 @@ xla::StatusOr XlaScatter( combiner, xla::XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); - TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); - gtl::ArraySlice indices_dims = + absl::Span indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - gtl::ArraySlice buffer_dims = - xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. @@ -58,7 +56,7 @@ xla::StatusOr XlaScatter( ") must be <= the rank of the buffer (shape: ", xla::ShapeUtil::HumanString(buffer_shape), ")"); } - indices_dims.pop_back(); + indices_dims.remove_suffix(1); } int64 num_indices = 1; @@ -81,104 +79,129 @@ xla::StatusOr XlaScatter( } } - // Shape of the non-indexed dimensions of the buffer. - std::vector buffer_shape_post_axes( - buffer_dims.begin() + num_index_dims, buffer_dims.end()); - - // Flatten the major dimensions of indices and updates into a single dimension - // for ease of iteration. 
- std::vector flat_indices_shape({num_indices}); - if (indices_are_vectors) { - flat_indices_shape.push_back(num_index_dims); + // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of + // shape [3,3]: + // NOTE: ***This case will not be generated by any of the tf.scatter ops.*** + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[3,2] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={0}, + // inserted_window_dims={1}, + // scatter_dims_to_operand_dims={1}, + // index_vector_dim=1 + // + // + // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of + // shape [3,3]: + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[2,3] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // + // + // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of + // shape [3,3,2] + // + // operand = s32[3,3,2] parameter(0) + // indices = s32[2,2] parameter(1) + // updates = s32[2,2] parameter(2) + // scatter = s32[3,3,2] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0,1}, + // index_vector_dim=1 + // + // + // Example of a scatter updating slices of shape [] in a tensor of shape [1,1] + // + // operand = s32[1,1] parameter(0) + // indices = s32[1] parameter(1) + // updates = s32[1] parameter(2) + // scatter = s32[1,1] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // Note that updates operand would be broadcasted into [1] in 
this case. + // + + xla::ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(indices_are_vectors + ? indices_shape.dimensions_size() - 1 + : indices_shape.dimensions_size()); + + int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); + int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 num_window_dims_in_updates = buffer_rank - num_index_dims; + + // If the rank of `updates` is 0 and does not match the expected rank of + // updates, broadcast `updates` to the expected shape of updates. + auto new_updates = updates; + std::vector expected_updates_dims(indices_dims.begin(), + indices_dims.end()); + for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) { + expected_updates_dims.push_back(buffer_shape.dimensions(dim)); + } + int64 expected_updates_rank = expected_updates_dims.size(); + if (updates_rank == 0 && expected_updates_rank != 0) { + new_updates = xla::Broadcast(updates, expected_updates_dims); + TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); + updates_rank = xla::ShapeUtil::Rank(updates_shape); } - std::vector flat_updates_shape({num_indices}); - flat_updates_shape.insert(flat_updates_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - - // Construct the initial values of the loop-carried Tensors. - auto flat_indices = xla::Reshape(indices, flat_indices_shape); - auto flat_updates = xla::Reshape(updates, flat_updates_shape); - auto init = {flat_indices, flat_updates, buffer}; - - // Constructs the loop body. 
The implementation of scatter is essentially: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // update = dynamic-slice(updates, i) - // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, - xla::XlaBuilder* body_builder) { - auto indices = loop_vars[0]; - auto updates = loop_vars[1]; - auto buffer = loop_vars[2]; - - auto zero_index = xla::ConstantLiteral( - body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); - - // Slice the i-th index from the indices array. - xla::XlaOp index; - auto indices_offset = xla::Reshape(i, {1}); - if (indices_are_vectors) { - indices_offset = xla::Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); - - index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); - index = xla::Collapse(index, {0, 1}); - } else { - index = xla::DynamicSlice(indices, indices_offset, {1}); + if (updates_rank > 0) { + for (int64 i = (updates_rank - num_window_dims_in_updates); + i < updates_rank; ++i) { + dim_numbers.add_update_window_dims(i); } + } - // Discard updates with negative indices, since some users expect this. - auto index_in_range = xla::ReduceAll( - xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), - xla::CreateScalarAndComputation(xla::PRED, body_builder)); - - // Make the index in bounds to prevent implementation defined behavior. - index = xla::Max(index, zero_index); - index = xla::Pad( - index, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - - // Slice the i-th index from the updates array. 
- auto updates_offset = xla::Reshape(i, {1}); - updates_offset = xla::Pad( - updates_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - std::vector flat_updates_slice_shape({1}); - flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - auto update = - xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); - - // Unflatten the major (iteration) dimensions of the slice to their - // original shape. - std::vector updates_slice_shape(num_index_dims, 1); - updates_slice_shape.insert(updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - update = xla::Reshape(update, updates_slice_shape); - - // Apply the update to the buffer. If there is a combiner, use it to merge - // the current values with the update. - auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); + for (int64 i = 0; i < num_index_dims; ++i) { + dim_numbers.add_inserted_window_dims(i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + + // Build the combiner computation. + xla::XlaComputation combiner_computation; + { + xla::XlaBuilder cb("scatter-combiner"); + auto xla_scalar_shape = + xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {}); + auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0"); + auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1"); if (combiner) { - update = combiner(current_value, update, body_builder); + combiner(p0, p1, &cb); } - // Use the current value instead of the update if the index is out of - // bounds. - update = xla::Select(index_in_range, update, current_value); - // Apply the update. 
- buffer = xla::DynamicUpdateSlice(buffer, update, index); - - return std::vector{indices, updates, buffer}; - }; - - TF_ASSIGN_OR_RETURN(auto outputs, - XlaForEachIndex(num_indices, indices_shape.element_type(), - body_fn, init, "scatter", builder)); - return outputs[2]; + combiner_computation = cb.Build().ConsumeValueOrDie(); + } + + VLOG(3) << "Scatter op:"; + VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape); + VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape); + VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); + VLOG(3) << " Scatter Dimension Numbers: "; + VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); + VLOG(3) << " update_window_dims: [" + << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]"; + VLOG(3) << " inserted_window_dims: [" + << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]"; + VLOG(3) << " scatter_dims_to_operand_dims: [" + << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",") + << "]"; + + return xla::Scatter(buffer, indices, new_updates, combiner_computation, + dim_numbers); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 13a5f1b850a612bddeeac39bef431c19925351ca..4cf478c4b9b4316f1cf43f45d1bf90afa648fb11 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -34,7 +34,11 @@ namespace tensorflow { // Otherwise, `indices_are_vectors`, then indices are multidimensional and the // minor dimension of `indices` represents a vector of indices. // -// If any indices are negative, the corresponding update is discarded. +// If `updates` is a scalar, then it will be broadcasted into the expected shape +// of updates. +// +// If any part of the update region is out-of-bounds, the corresponding update +// is discarded. 
// // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 37b2240b45b4ae6a587c827cfdfa1096b4e1737e..6524c2a9b1ada632d80edd234272760c2b545cc4 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -110,9 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks( - xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // Input is a batch of square lower triangular square matrices. 
Its shape is @@ -216,7 +216,7 @@ xla::XlaOp InvertDiagonalBlocks( dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -245,7 +245,7 @@ xla::XlaOp InvertDiagonalBlocks( xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -346,7 +346,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index ac42a4835295b7cb52697710d738f4728d3983d1..2303234f361e54cd2a0ad495cb03b371bed76877 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,11 +57,10 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. 
-xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp TriangularSolve( + xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7..804671fbc75b0a5a6e04b204822b6f084013cd8b 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::C64: - literal = 
std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: - literal = std::move( - *xla::LiteralUtil::CreateR0(static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::F16: - literal = std::move(*xla::LiteralUtil::CreateR0( - static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; @@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end) { +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_RET_CHECK(start.size() == end.size()); @@ -124,9 +124,10 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); + auto major_dims = xla::AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); // Prepends 0s in the major dim std::vector padded_start(n_dims, 0); @@ -143,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, }); } -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys) { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { std::vector output(xs.size() + ys.size()); std::copy(xs.begin(), xs.end(), output.begin()); std::copy(ys.begin(), 
ys.end(), output.begin() + xs.size()); @@ -152,8 +153,8 @@ std::vector ConcatVectors(gtl::ArraySlice xs, } xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes) { + absl::Span starts, + absl::Span sizes) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -161,9 +162,10 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); + auto major_dims = xla::AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); auto padded_starts = PrependZerosInMajorDims(x, starts); auto padded_sizes = ConcatVectors(major_dims, sizes); return xla::DynamicSlice(x, padded_starts, padded_sizes); @@ -171,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, } xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
@@ -189,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -204,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts) { + absl::Span starts) { auto padded_starts = PrependZerosInMajorDims(x, starts); return xla::DynamicUpdateSlice(x, update, padded_starts); } xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts) { + absl::Span starts) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index b4905c952820a45371e090aa98466654e2db9661..80e9e5b002d49581209e608b98606e02709c5876 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. 
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. @@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end); // Returns the concatenation of `xs` and `ys`. -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys); +std::vector ConcatVectors(absl::Span xs, + absl::Span ys); // Performs a dynamic slice in the minor dimensions of a Tensor. xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes); + absl::Span starts, + absl::Span sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts); + absl::Span starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. 
xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index d64394f1401d7ceea004a59c991ef6f4a1c58b41..594ab1dfd0700f47501712183f6efe62d17e15e7 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -24,7 +24,7 @@ namespace tensorflow { xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; @@ -47,7 +47,7 @@ xla::StatusOr> XlaWhileLoop( // Build the condition. std::unique_ptr cond_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { auto parameter = xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); @@ -61,7 +61,7 @@ xla::StatusOr> XlaWhileLoop( // Build the body. 
std::unique_ptr body_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_body")); + builder->CreateSubBuilder(absl::StrCat(name, "_body")); { auto parameter = xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); @@ -84,15 +84,15 @@ xla::StatusOr> XlaWhileLoop( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { auto while_cond_fn = - [&](gtl::ArraySlice values, + [&](absl::Span values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, + auto while_body_fn = [&](absl::Span values, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::XlaOp iteration = values[0]; diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 9493b1f109be0725f7f733b9f9da664264275a69..f2134bb4495a12b8342961d96f70e7737f816c7d 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,24 +19,24 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(gtl::ArraySlice, +typedef std::function(absl::Span, xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. 
Takes as input a sequence of input values // and returns a sequence of output values. typedef std::function>( - gtl::ArraySlice, xla::XlaBuilder*)> + absl::Span, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -50,7 +50,7 @@ typedef std::function>( xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. @@ -59,13 +59,13 @@ xla::StatusOr> XlaWhileLoop( // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. typedef std::function>( - xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> + xla::XlaOp, absl::Span, xla::XlaBuilder*)> ForEachIndexBodyFunction; xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 77da1bf29ced60e490f07abad41cf8ce96232982..20103ec3ae00b57723e05326dbbb1b0f6e1a671a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -49,9 +49,8 @@ Status HostTensorToMutableBorrowingLiteral( return Status::OK(); } -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal) { +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal) { std::vector buf_ptrs; buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); diff --git 
a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 09d6fa811669b422532673540e4da47f47e6be4e..1db7470ee2a839099454b772d4833492e033bc92 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,11 +18,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -43,9 +43,8 @@ Status HostTensorToMutableBorrowingLiteral( // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal); +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index a3404c2b3df7bf25011359d1f5f5b88c29a3f83b..15f4c38da29507da9e092c1d5725b5f95a81d1b9 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -22,51 +22,61 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. 
- { - std::vector int64_values = {1, 2, 3}; - std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); - Tensor host_tensor; - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) - .error_message()); - EXPECT_EQ( - "Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); - EXPECT_TRUE( - LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int64_values)); - } + std::vector int64_values = {1, 2, 3}; + xla::Literal int64_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int64_values)); +} + +template +using LiteralUtilTest = ::testing::Test; +using Types = + ::testing::Types, std::pair, + std::pair, std::pair, + std::pair>; + +TYPED_TEST_CASE(LiteralUtilTest, Types); + +TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { + using int_type = typename TypeParam::first_type; + using qint_type = typename TypeParam::second_type; - { - // Repeat tests with int32. 
- Tensor host_tensor; - std::vector int32_values = {10, 11}; - std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int32_values)); + Tensor host_tensor; + std::vector int_values = {10, 11}; + xla::Literal int_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) - .ok()); - std::vector qint32_values = {10, 11}; - test::ExpectTensorEqual(host_tensor, - test::AsTensor(qint32_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, + &host_tensor) + .ok()); + std::vector qint_values = {10, 11}; + test::ExpectTensorEqual(host_tensor, + test::AsTensor(qint_values)); - EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) - .error_message()); - } + EXPECT_EQ( + error::INVALID_ARGUMENT, + LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code()); } +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 2cd9ae799f06afdcbae5429ef8caffd3b4d29c29..bd2c0a5ee88869ba60701c0a7ace05857452eed9 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -83,7 +83,7 @@ lhs_dilation: dilation to apply between input elements rhs_dilation: dilation to apply between kernel elements feature_group_count: number of feature groups for grouped convolution. dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. 
-precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. )doc"); REGISTER_OP("XlaDot") @@ -102,7 +102,37 @@ Wraps the XLA ConvGeneralDilated operator, documented at lhs: the LHS tensor rhs: the RHS tensor dimension_numbers: a serialized xla::DotDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. +)doc"); + +REGISTER_OP("XlaDynamicSlice") + .Input("input: T") + .Input("start_indices: Tindices") + .Input("size_indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DynamicSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +. + +DynamicSlice extracts a sub-array from the input array at dynamic +start_indices. The size of the slice in each dimension is passed in +size_indices, which specify the end point of exclusive slice intervals in each +dimension -- [start, start + size). The shape of start_indices must have rank 1, +with dimension size equal to the rank of operand. + +input: A `Tensor` of type T. + +start_indices: Rank 1 tensor of N integers containing the starting indices of + the slice for each dimension. Value must be greater than or equal to zero. + +start_indices: List of N integers containing the slice size for each + dimension. Each value must be strictly greater than zero, and start + size + must be less than or equal to the size of the dimension to avoid + implementation defined behavior. 
)doc"); REGISTER_OP("XlaDynamicUpdateSlice") @@ -253,6 +283,8 @@ REGISTER_OP("XlaReduceWindow") .Input("init_value: T") .Input("window_dimensions: Tindices") .Input("window_strides: Tindices") + .Input("base_dilations: Tindices") + .Input("window_dilations: Tindices") .Input("padding: Tindices") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") @@ -324,12 +356,33 @@ Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort . -Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. +Sorts a tensor. Currently only sorts in ascending order are supported. input: A `Tensor` of type T. output: A `Tensor` of type T. )doc"); +REGISTER_OP("XlaKeyValueSort") + .Input("keys: K") + .Input("values: V") + .Output("sorted_keys: K") + .Output("sorted_values: V") + .Attr("K: realnumbertype") + .Attr("V: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + +keys: A `Tensor` of type K. +values: A `Tensor` of type V. +sorted_keys: A `Tensor` of type K. +sorted_values: A `Tensor` of type V. +)doc"); + // TODO(b/37549631) setting the While Op to always be stateful is too // conservative. 
REGISTER_OP("XlaWhile") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 3626de375ea9ac12e40ea5b5b591bb6d5262adbc..5e86b5d8ec0a2690f004bc67decea09185d9cbb6 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast def broadcast(x, dims, name=None): x = ops.convert_to_tensor(x) - shape = array_ops.concat( - [constant_op.constant(dims), - array_ops.shape(x)], axis=0) + shape = array_ops.concat([constant_op.constant(dims), + array_ops.shape(x)], + axis=0) return array_ops.broadcast_to(x, shape, name=name) @@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) -def dynamic_slice(x, starts, sizes, name=None): - # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not - # a compile-time constant. This doesn't exactly mimic the semantics of dynamic - # slice if the slice is out of bounds. - return array_ops.slice(x, starts, sizes, name=name) - - +dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice # TODO(phawkins): generalize tf.pad to support interior padding, and then remove @@ -326,6 +320,8 @@ def reduce_window(operand, reducer, window_dimensions, window_strides=None, + base_dilations=None, + window_dilations=None, padding=None, name=None): """Wraps the XLA ReduceWindow operator. @@ -338,22 +334,27 @@ def reduce_window(operand, init: a scalar tensor representing the initial value for the reduction reducer: a reduction function that combines a pair of scalars. window_dimensions: shape of the window, as a list of integers - window_strides: inter-window strides, as a list of integers. Optional; - if omitted, defaults to strides of 1. + window_strides: inter-window strides, as a list of integers. Optional; if + omitted, defaults to strides of 1. padding: padding to apply to 'operand'. 
List of (low, high) pairs of integers that specify the padding to apply before and after each dimension. Optional; if omitted, defaults to no padding. name: the operator name, or None. + Returns: A tensor that represents the output of the reduce_window operator. """ window_strides = window_strides or [1] * len(window_dimensions) + base_dilations = base_dilations or [1] * len(window_dimensions) + window_dilations = window_dilations or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) return gen_xla_ops.xla_reduce_window( input=operand, init_value=init, window_dimensions=window_dimensions, window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, computation=reducer, name=name) @@ -383,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort +key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 32ba6df2e6daa2add468a1bc0559d42606d1a9a6..72b240996fb4d9dcb5f5dfd919da618cbae08c16 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,10 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/lib/gtl/flatmap.h" +#include "absl/container/flat_hash_map.h" namespace tensorflow { -/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( +/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( XlaResourceOpKind op_kind) { switch (op_kind) { case XlaResourceOpKind::kRead: @@ -30,11 +30,11 @@ namespace tensorflow { } } -static gtl::FlatMap* CreateResourceOpInfoMap() { - gtl::FlatMap* result = - new gtl::FlatMap; +static absl::flat_hash_map* +CreateResourceOpInfoMap() { + auto* result = new absl::flat_hash_map; - auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,23 +103,23 @@ static gtl::FlatMap* CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const absl::flat_hash_map& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = - CreateResourceOpInfoMap(); + static absl::flat_hash_map* + op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } -const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { - const gtl::FlatMap& op_infos = +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { + const absl::flat_hash_map& op_infos = GetStaticResourceOpInfoMap(); auto it = op_infos.find(op); return it == op_infos.end() ? 
nullptr : &it->second; } namespace resource_op_table_internal { -std::vector GetKnownResourceOps() { - std::vector result; +std::vector GetKnownResourceOps() { + std::vector result; for (const auto& p : GetStaticResourceOpInfoMap()) { result.push_back(p.first); } diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h index 7f627a64c6e8298a427cd87d25d4ba24835bf542..61c7a56ff0d4adb75e93ced3155b37102763c652 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.h +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/stringpiece.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" // Exposes information about the resource operations supported by tf2xla in a @@ -47,7 +47,7 @@ class XlaResourceOpInfo { XlaResourceOpKind kind() const { return op_kind_; } XlaResourceKind resource_kind() const { return resource_kind_; } - static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind); private: XlaResourceOpKind op_kind_; @@ -57,13 +57,13 @@ class XlaResourceOpInfo { // Returns a XlaResourceOpInfo describing `op` if it is a resource operation // supported by tf2xla, otherwise returns null (i.e. if this returns null then // `op` is either not a resource operation or is unsupported by XLA). -const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op); namespace resource_op_table_internal { // NB! Implementation detail exposed for unit testing, do not use. // // Returns the set of resource operations known by this module. 
-std::vector GetKnownResourceOps(); +std::vector GetKnownResourceOps(); } // namespace resource_op_table_internal } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index 0343f80de9fed114a0097b981233277c3e12b378..956f597301d28d781a9df7ab2086ed79d4c8bf9d 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -33,8 +34,8 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { } TEST(ResourceOperationTableTest, HaveAllResourceOps) { - gtl::FlatMap known_resource_ops; - for (StringPiece known_resource_op : + absl::flat_hash_map known_resource_ops; + for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( known_resource_ops.insert({string(known_resource_op), false}).second); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 9d1992205b02665b99b1bd15b7b65a1fb8c35a51..b589512dcdfa32050281120aba6a5ae89a980c2f 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, // Convert a TensorShape into the equivalent XLA Shape proto. 
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + *shape = TensorShapeToXLAShape(type, tensor_shape); + return Status::OK(); +} + +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector dimensions(rank); std::vector layout(rank); @@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); - - *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); - return Status::OK(); + return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 58240b9c965a194b9380ac7cd477ce7344e5ebe3..f7e34a5b40c2f9244c029ed325a76322b8cf54dd 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape); +// Converts a TensorShape into the equivalent XLA Shape proto, taking an +// xla::PrimitiveType to specify the element type. This never fails. 
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 2d7eb8b915b8245ba6573c30b2eb15b12fc3a1b4..8aae498be1042b5a55e849a03d438cd54dafca83 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -17,7 +17,6 @@ limitations under the License. #include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..b233e6b2c28e1968bb74901fc684e808ae45ab60 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/side_effect_util.h" + +#include "absl/strings/numbers.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; + +const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; + +std::set CalculateTokenInputsForOutputToken(const Graph& g) { + std::set results; + Node* first_side_effecting_node_on_path = nullptr; + ReverseDFS(g, + [&](Node* n) { + std::vector token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + results.insert(n->name()); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); + return results; +} + +bool HasSideEffectingNodes(const Graph& g) { + for (Node* n : g.nodes()) { + std::vector token_input_nodes; + if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) + .ok() && + !token_input_nodes.empty()) { + return true; + } + } + return false; +} + +Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core) { + for (const auto& hc_core : list_from_attr) { + std::vector parts = str_util::Split(hc_core, ":"); + if (parts.size() != 2) { + return errors::InvalidArgument( + "Malformed host_compute_core entry ", hc_core, + " should be :."); + } + int core; + if (!absl::numbers_internal::safe_strto32_base(parts[1], &core, 10)) { + return errors::InvalidArgument("Malformed host_compute_core entry ", + hc_core, + " part after ':' should be an integer."); + } + if (host_compute_core->find(parts[0]) != host_compute_core->end()) { + return errors::InvalidArgument( + "Duplicate host_compute_core 
entry for cluster ", parts[0]); + } + (*host_compute_core)[parts[0]] = core; + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 0000000000000000000000000000000000000000..f22ddb2f58e1fa5c10ca0fdb956d9136942388b7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side-effect. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has side-effect dependency on current graph's token input. 
+extern const char kXlaTokenArgNodeName[]; + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. +bool HasSideEffectingNodes(const Graph& g); + +// Parse the mapping from outside_compilation_subgraph name to core number, +// which is specified in an attr as a list of strings +// :. +Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed23f3fca0f59b131dc73152e0947b72..4ffc94ae3bc7c930720cd625a7856443c77be666 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -44,4 +46,15 @@ Status InstantiateFunctionForTest(const string& name, } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. 
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index f34af2d67debe8bfa4abcad19e42c55ea40c4e82..b22d53805d83069052cc5e16020d6c540d618a82 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -41,7 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -75,7 +76,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, auto node_it = node_map.find(remap_it->second); if (node_it == node_map.end()) { // Strip off the aot_feed_#/ prefix. - StringPiece name(remap_it->second); + absl::string_view name(remap_it->second); const auto index = name.find('/'); if (index > 0) name.remove_prefix(index + 1); return errors::InvalidArgument( @@ -89,7 +90,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, // explicitly specify or override them. 
Node* arg_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) + NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp) .Attr("T", BaseType(feed_node->output_type(output_index))) .Attr("index", arg_index) .Attr(kFeedIdAttr, TensorIdToString(feed.id())) @@ -136,7 +137,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // Connects fetch_node -> retval_node. Node* retval_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) + NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp) .Input(fetch_node, id.output_index()) .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) .Attr("index", ret_index) @@ -256,7 +257,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( - strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); + absl::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); @@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + + // Functionalize control flow. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); + // After control flow functionalization, we might have more FunctionDef's + // (then/else branch, loop body). Add them to the graph. 
+ TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); + *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 56f7045a98201ed398244f9e3f5ff23788135b75..ab26d939ccba75ce58609ffd71c7ccadbe90cfa8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) { // Set up arguments. auto x_literal = xla::LiteralUtil::CreateR0(10); auto y_literal = xla::LiteralUtil::CreateR0(32); - auto x_global_or = client->TransferToServer(*x_literal); - auto y_global_or = client->TransferToServer(*y_literal); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); TF_EXPECT_OK(y_global_or.status()); std::unique_ptr x_global = @@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) { auto result_or = client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); - std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index ebdf2fd741a49c5eb578e733218bd332ee480522..cc83db0562dd4ef1ae7b7a718a8f2e407acbfa1e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -20,20 +20,23 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace +const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; + Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { @@ -112,8 +117,8 @@ Status AddPlaceholdersForFeeds( const string name_port = TensorIdToString(feed->id()); PlaceholderInfo& info = placeholder_info[name_port]; info.feed = feed; - info.placeholder_name = strings::StrCat( - "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); + info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(), + "/", feed->id().node_name()); (*feed_remapping)[name_port] = info.placeholder_name; } @@ -233,7 +238,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. 
for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = std::string(id.first); + const string node_name = string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); @@ -258,7 +263,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, } string TensorIdToString(const tf2xla::TensorId& id) { - return strings::StrCat(id.node_name(), ":", id.output_index()); + return absl::StrCat(id.node_name(), ":", id.output_index()); } Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { @@ -289,7 +294,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef) { for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { if (constraint.name() == name) { @@ -323,4 +328,141 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } +// TODO(b/77601805): add tests for associated function related stuff. +bool HasAssociatedFunction(const NodeDef& node_def, + const FunctionLibraryDefinition* fld) { + if (fld->Contains(node_def.op())) { + return true; + } + + if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { + // Gradient op has "f" attr, which is set to the function we are getting + // gradient for. We need to functionalize the gradient function. + return true; + } + + for (const auto& iter : node_def.attr()) { + if (iter.second.has_func()) { + return true; + } + } + + return false; +} + +std::vector GetAssociatedFunctions( + const Node& node, const FunctionLibraryDefinition* fld) { + std::vector results; + const string& op = node.type_string(); + if (fld->Contains(op)) { + // This is a function call node. 
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs)); + } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { + // This is a SymbolicGradient op. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); + } else { + // Collect all function attrs for the node. + for (auto& iter : node.attrs()) { + if (iter.second.has_func()) { + VLOG(2) << "Found function attr for node " << node.name() << ": " + << iter.first << " = " << iter.second.func().name(); + results.emplace_back(AssociatedFunctionInfo::FunctionAttr( + iter.second.func().name(), iter.second.func().attr(), iter.first)); + } + } + } + return results; +} + +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + // Change this node to call the new function. + NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + for (auto attr : node->attrs()) { + builder.Attr(attr.first, attr.second); + } + for (int i = 0; i < node->num_inputs(); i++) { + Node* input_node; + TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); + builder.Input(input_node->name(), i, node->input_type(i)); + } + builder.Device(node->assigned_device_name().empty() + ? 
node->requested_device() + : node->assigned_device_name()); + NodeDef node_def; + TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); + Status s; + Node* new_node = graph->AddNode(node_def, &s); + TF_RETURN_IF_ERROR(s); + for (auto edge : node->in_edges()) { + graph->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input()); + } + for (auto edge : node->out_edges()) { + graph->AddEdge(new_node, edge->src_output(), edge->dst(), + edge->dst_input()); + } + graph->RemoveNode(node); + break; + } + case AssociatedFunctionInfo::kSymbolicGradient: { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr( + node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func)); + GradientDef gradient_def; + gradient_def.set_function_name(func.name()); + gradient_def.set_gradient_func(rewritten_function_name); + string original_grad_func = fld->FindGradient(func.name()); + if (original_grad_func.empty()) { + TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def)); + } else if (original_grad_func != rewritten_function_name) { + TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def)); + } + break; + } + case AssociatedFunctionInfo::kFunctionAttr: { + // Change function attr to rewritten functions. 
+ NameAttrList func; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); + node->ClearAttr(associated_function.attr_name()); + func.set_name(rewritten_function_name); + node->AddAttr(associated_function.attr_name(), func); + break; + } + } + + return Status::OK(); +} + +Status CachedFunctionHandles::GetOrInstantiate( + const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle) { + string canonicalized_name = Canonicalize(func_name, attrs); + auto iter = handles_.find(canonicalized_name); + if (iter != handles_.end()) { + *handle = iter->second; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle)); + handles_[canonicalized_name] = *handle; + return Status::OK(); +} + +Status CachedFunctionHandles::ReleaseAllHandles() { + Status result; + for (auto iter : handles_) { + result.Update(flr_->ReleaseHandle(iter.second)); + } + handles_.clear(); + return result; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 33620ef810bd4fe897f384474e661e341a448b93..b974b998229982afc9168dcaf0799cfddd965a04 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -53,12 +54,120 @@ string TensorIdToString(const tf2xla::TensorId& id); Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); // Add an allowed data type to the AttrConstraint with the given name. 
-void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef); // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionAttr = 0, + kFunctionCallNode = 1, + kSymbolicGradient = 2, + }; + + // The function is an attr of the node. + static AssociatedFunctionInfo FunctionAttr(const string& func_name, + const AttrValueMap& attrs, + const string& attr_name) { + return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name); + } + + // The node is a function call. + static AssociatedFunctionInfo FunctionCall(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs, + /*attr_name=*/""); + } + + // The node is a SymbolicGradient op. + static AssociatedFunctionInfo SymbolicGradient(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs, + /*attr_name=*/""); + } + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name, + const AttrValueMap& attrs, const string& attr_name) + : type_(type), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. 
+ string attr_name_; +}; + +// Returns if the NodeDef has associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + const FunctionLibraryDefinition* fld); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient", +// and returned attrs will be this node's attributes; +// 3. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector GetAssociatedFunctions( + const Node& node, const FunctionLibraryDefinition* fld); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// remove the old node; +// 2. For SymbolicGradient op, add or replace GradientDef in +// FunctionLibraryDefinition; +// 3. For nodes like XlaWhile/XlaIf, modify their function attributes. +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Attribute to mark nodes to be executed on host. +extern const char kXlaOutsideCompilationAttrName[]; + +// Class to act as cache for FunctionLibraryRuntime::Handle objects. +class CachedFunctionHandles { + public: + CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {} + + // Populates `handle` for requested function and attributes. If we have + // instantiated the function with the same attributes before, `handle` will be + // cached handle; otherwise instantiate the function and populate `handle`. + Status GetOrInstantiate(const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle); + + // Releases all handles in the cache. Returns first non-OK status if any; + // returns OK otherwise. 
+ Status ReleaseAllHandles(); + + ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); } + + private: + FunctionLibraryRuntime* flr_; + std::map handles_; + + TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles); +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 2b1f724dc7b2e2bb6d06115827f92bf0670955b3..202e929315cacd4d6cdfc69d50639d8a427ec6c2 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -16,18 +16,22 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -153,7 +157,7 @@ static tf2xla::Config FetchesConfig(std::vector fetches) { tf2xla::Config config; for (const auto& fetch_node_name : fetches) { auto* fetch = config.add_fetch(); - fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); + fetch->set_name(absl::StrCat("fetch_", fetch_node_name)); 
fetch->mutable_id()->set_node_name(fetch_node_name); } return config; @@ -255,5 +259,75 @@ TEST(SetNodeShardingFromNeighbors, Basic) { EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); } +REGISTER_OP("One") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a tensor with a single element (1) of type T. + +y: A scalar in type T. + +)doc"); + +// Tests that CachedFunctionHandles class works. +TEST(CachedFunctionHandles, Basic) { + FunctionDef func = FunctionDefHelper::Define( + // Name + "TestFunc", + // Args + {}, + // Return values + {"y:T"}, + // Attr def + {"T:{float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "One", {}, {{"T", "$T"}}}, + }); + FunctionDefLibrary proto; + *proto.add_function() = func; + FunctionLibraryDefinition fld(OpRegistry::Global(), proto); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld, + OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + CachedFunctionHandles cached_function_handles(flr); + + // Tests that GetOrInstantiate() works. + FunctionLibraryRuntime::Handle first_handle; + AttrValue attr; + attr.set_type(DT_FLOAT); + AttrValueMap attrs; + attrs["T"] = attr; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &first_handle)); + + // Tests that we can get FunctionBody. + const FunctionBody* body = flr->GetFunctionBody(first_handle); + EXPECT_NE(body, nullptr); + + // Tests that GetOrInstantiate() returns cached handle when called with same + // function name and attributes. 
+ FunctionLibraryRuntime::Handle second_handle; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &second_handle)); + EXPECT_EQ(first_handle, second_handle); + + // Tests that GetOrInstantiate() returns new handle when called with same + // function name but different attributes. + attr.set_type(DT_INT32); + attrs["T"] = attr; + FunctionLibraryRuntime::Handle third_handle; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &third_handle)); + EXPECT_NE(first_handle, third_handle); + + // Tests that ReleaseAllHandles() works. + TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c969212a1bfaa6cab0d896ee074cfd4e2b283ae4..d00b1376620c0c9d112c7d7426758f6d3f25e86f 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { *type = xla::PRED; return Status::OK(); case tensorflow::DT_INT8: + case tensorflow::DT_QINT8: *type = xla::S8; return Status::OK(); case tensorflow::DT_INT16: + case tensorflow::DT_QINT16: *type = xla::S16; return Status::OK(); case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: *type = xla::S32; return Status::OK(); case tensorflow::DT_INT64: *type = xla::S64; return Status::OK(); case tensorflow::DT_UINT8: + case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); case tensorflow::DT_UINT16: + case tensorflow::DT_QUINT16: *type = xla::U16; return Status::OK(); case tensorflow::DT_UINT32: @@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); - case tensorflow::DT_QUINT8: - *type = xla::U8; - return Status::OK(); - case tensorflow::DT_QINT32: - *type = xla::S32; 
- return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h index bda667eb1f16b80da415c7c5205df96a4ae93e4c..6354216eee7978dc2b4a59f5792a70f67d530b9b 100644 --- a/tensorflow/compiler/tf2xla/type_util.h +++ b/tensorflow/compiler/tf2xla/type_util.h @@ -25,6 +25,14 @@ namespace tensorflow { // Converts a Tensorflow DataType to an XLA PrimitiveType. Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type); +// N.B.: there is intentionally no function to convert an XLA PrimitiveType to +// a TensorFlow DataType. The mapping from TF types to XLA types is not +// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the +// inverse would not be a well-defined function. If you find that you want the +// inverse mapping, then most likely you should be preserving the original +// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow +// type. 
+ } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d98237bd5c9288e6337e10c19c2d7574ad2e4c97..7f860500c75667a920505dbf498e3da4b388fb90 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,12 +76,11 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, DeviceType type) - : LocalDevice( - options, - Device::BuildDeviceAttributes( - strings::StrCat("/device:", type.type(), ":0"), type, - Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type()))), + : LocalDevice(options, Device::BuildDeviceAttributes( + absl::StrCat("/device:", type.type(), ":0"), + type, Bytes(256 << 20), DeviceLocality(), + absl::StrCat("device: XLA compilation device ", + type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index eabfc6b6e26f7e6ab41c8744b2b10d8ea13bd3ca..b2c57e88803e0661a9a514f844dff97ff9edf2ea 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" @@ -149,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); + VLOG(4) << "Function " << function.name() << " in flib_runtime_"; + } else { + VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -191,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, std::unique_ptr graph = GetGraph(fbody); + // Clear the "_kernel" attribute if it is set to "host". This is used to + // indicate that a computation should happen on the host instead of the + // accelerator, but doesn't make sense in XLA. + const char* const kKernelAttr = "_kernel"; + for (Node* n : graph->nodes()) { + string value; + if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { + n->ClearAttr(kKernelAttr); + } + } + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have // core assignments here. @@ -198,14 +212,14 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. 
for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == "_Arg") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == "_Retval") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -213,8 +227,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_", function_id), - *graph); + absl::StrCat("xla_compile_function_", function_id), *graph); } VLOG(1) << "===================================================="; @@ -292,6 +305,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, "Invalid resource type in XLAShapeForArgument()"); } } + case XlaCompiler::Argument::kToken: { + *xla_shape = xla::ShapeUtil::MakeTokenShape(); + return Status::OK(); + } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); } @@ -319,8 +336,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, - step_container.get()); + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); @@ -328,10 +344,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, } // Builds the XLA computation. 
-// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. +// `args` is the list of input arguments, `retvals` is the list of retvals +// produced by _Retval operators, in index order. // If `return_updated_values_for_all_resources` is true, all resources will be // included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. @@ -361,6 +375,9 @@ Status BuildComputation( if (retval.has_constant_value()) { output.is_constant = true; output.constant_value = retval.constant_value(); + } else if (retval.resource() != nullptr) { + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); } else { output.is_constant = false; elems.push_back(retval.handle()); @@ -487,7 +504,8 @@ Status XlaCompiler::BuildArguments( } break; - case XlaCompiler::Argument::kParameter: { + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } @@ -495,7 +513,8 @@ Status XlaCompiler::BuildArguments( arg_expression.set_constant_value(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling constant args"); } } @@ -518,7 +537,7 @@ Status XlaCompiler::BuildArguments( // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { - if (StringPiece(n->type_string()) != "_Arg") continue; + if (absl::string_view(n->type_string()) != "_Arg") continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0 && index < args.size()) @@ -577,7 +596,7 @@ Status XlaCompiler::BuildArguments( builder, core == -1 ? 
absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], - strings::StrCat("arg", i)); + absl::StrCat("arg", i)); } } @@ -613,9 +632,14 @@ Status XlaCompiler::BuildArguments( arg_expression.set_handle(arg_handles[i]); } break; + case XlaCompiler::Argument::kToken: { + arg_expression.set_handle(arg_handles[i]); + break; + } case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling handles"); } } @@ -639,7 +663,7 @@ Status XlaCompiler::CompileSingleOp( // dependency edge to the _SOURCE node. for (int64 i = 0; i < ctx->num_inputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); Status status = NodeBuilder(name, "_Arg") .ControlInput(graph->source_node()) .Attr("T", ctx->input_dtype(i)) @@ -652,7 +676,7 @@ Status XlaCompiler::CompileSingleOp( // Similarly with return values, create dummy _Retval nodes fed by `node`. 
for (int64 i = 0; i < ctx->num_outputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); Status status = NodeBuilder(name, "_Retval") .Input(main_node, i) .Attr("T", ctx->expected_output_dtype(i)) @@ -688,7 +712,7 @@ Status ValidateGraph(const Graph* graph, const DeviceType& device_type, const string& name) { auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", s.error_message(), ")", FormatNodeForError(*node))); @@ -729,18 +753,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); - // Converts Tensorflow's graph control-flow constructs into functional - // control-flow that can be compiled into XLA code. - TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), - graph.get(), local_flib_def_.get())); - // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, @@ -753,23 +772,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, &options_.shape_representation_fn); core::ScopedUnref context_unref(context); + std::vector real_args(args); + int token_input_index = -1; + if (options.add_token_input_output) { + // Add extra token input. 
+ token_input_index = real_args.size(); + + XlaCompiler::Argument token_arg; + token_arg.kind = XlaCompiler::Argument::kToken; + real_args.push_back(token_arg); + } + std::vector arg_expressions; std::vector arg_cores; - TF_RETURN_IF_ERROR( - BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, - &arg_cores, &arg_expressions, &result->input_mapping, - &result->xla_input_shapes, options.is_entry_computation)); + TF_RETURN_IF_ERROR(BuildArguments( + *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes, + options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + PushNodeTokenMapping(); + // Use std::set instead of std::unordered_set to ensure determinism. + std::set output_node_token_inputs; + if (token_input_index != -1) { + // Original token comes from input. + auto arg_expression = context->args()[token_input_index]; + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); + + // Calculate token inputs for output token. + output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); + + // If there's no side-effecting op in the graph, use token input as token + // output. + if (output_node_token_inputs.empty()) { + output_node_token_inputs.insert(kXlaTokenArgNodeName); + } + } else if (options.is_entry_computation) { + // Original token is manually created. + if (HasSideEffectingNodes(*graph)) { + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + } + } + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, flib_runtime_, NextStepId())); + if (token_input_index != -1) { + // Add extra token output. 
+ std::vector token_inputs; + for (const auto& node_name : output_node_token_inputs) { + auto token_or = GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + TF_RETURN_IF_ERROR( + context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + } + TF_RETURN_IF_ERROR(PopNodeTokenMapping()); int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( - args, arg_cores, context->retvals(), context->resources(), + real_args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, @@ -830,8 +897,8 @@ Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, namespace { -void SetTransfer(const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes, +void SetTransfer(const string& key, absl::Span types, + absl::Span shapes, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); @@ -845,8 +912,8 @@ void SetTransfer(const string& key, gtl::ArraySlice types, } // namespace Status XlaCompiler::SetDeviceToHostMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -872,8 +939,8 @@ Status XlaCompiler::GetDeviceToHostShapes( } Status XlaCompiler::SetHostToDeviceMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( 
"Duplicate calls to SetHostToDeviceMetadata with key ", key); @@ -908,4 +975,47 @@ Status XlaCompiler::SetHostComputeControlDependency( return Status::OK(); } +void XlaCompiler::PushNodeTokenMapping() { + node_token_mapping_stack_.emplace(std::map{}); +} + +Status XlaCompiler::PopNodeTokenMapping() { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " + "empty."); + } + node_token_mapping_stack_.pop(); + return Status::OK(); +} + +Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp& op) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling SetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); + if (!insert_result.second) { + return errors::FailedPrecondition("Token mapping already exists for node ", + node_name); + } + return Status::OK(); +} + +xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling GetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto iter = node_token_mapping_stack_.top().find(node_name); + if (iter == node_token_mapping_stack_.top().end()) { + return errors::FailedPrecondition("Cannot find token mapping for node ", + node_name); + } + return iter->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index da1ae02f324fbaf4079e04fa128215c2114522b0..2cc603a58016a509fafdf6f95423dd6c0864cce3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. + kToken, }; Kind kind = kInvalid; @@ -179,10 +185,15 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add an XLA token input & output to the graph/function. + bool add_token_input_output = false; }; struct OutputDescription { // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. DataType type; TensorShape shape; @@ -190,6 +201,10 @@ class XlaCompiler { // 'Tensor' is in host memory. bool is_constant = false; Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; }; // Describes a variable write side effect of the computation. @@ -345,8 +360,8 @@ class XlaCompiler { // Sets the shapes and types for the device to host transfer associated with // 'key'.
Status SetDeviceToHostMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // Gets the shapes the device to host transfer associated with 'key'. Status GetDeviceToHostShapes(const string& key, @@ -355,8 +370,8 @@ class XlaCompiler { // Sets the shapes and types for the host to device transfer associated with // 'key'. Status SetHostToDeviceMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // In order to avoid deadlocks from dependencies in host computations, it can // be necessary to enforce a partial order on the execution of HostCompute @@ -378,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -442,6 +462,15 @@ class XlaCompiler { std::unordered_map host_compute_control_output_; + // This is used to store the mapping from node name to token output. Side-effecting + // ops call SetNodeToken() to record their token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning.
+ std::stack> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 740f6dc25cdce027341aaba7e4da27ac8d55ed94..72b17d04fc42eb00781e96b412465b73fb29a5c2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -205,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -261,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } // Tests that the compiler doesn't reorder the parameters. @@ -405,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -440,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -616,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { auto instr1 = c1.instructions(j); auto instr2 = c2.instructions(j); instr1.clear_name(); + instr1.clear_id(); + instr1.clear_operand_ids(); instr2.clear_name(); - // The names of instructions were uniquified by the XlaBuilder, the rest - // of the fields should be identical. + instr2.clear_id(); + instr2.clear_operand_ids(); + // The names of instructions were uniquified by the XlaBuilder and the + // unique ids may be different, the rest of the fields should be + // identical. 
string str1, str2; + LOG(INFO) << "instr1 = " << instr1.DebugString(); + LOG(INFO) << "instr2 = " << instr2.DebugString(); instr1.AppendPartialToString(&str1); instr2.AppendPartialToString(&str2); EXPECT_EQ(str1, str2); @@ -669,34 +664,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. - std::unique_ptr input_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr input_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr output_read = - xla::LiteralUtil::CreateR0(42); - std::unique_ptr output_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr output_grad1 = - xla::LiteralUtil::CreateR1({0, 1}); - std::unique_ptr output_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); + 
xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -861,6 +848,28 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { << status.error_message(); } +void RunAndCheckVariablesComputation( + xla::Client* client, const XlaCompiler::CompilationResult& result) { + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client->TransferToServer(param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client->TransferToServer(param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, Variables) { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -892,34 +901,85 @@ TEST_F(XlaCompilerTest, Variables) { // Compiles the graph. 
XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); +} + +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0); + auto d = ops::_Retval(scope.WithOpName("D"), var, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kVariable; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", std::move(graph), args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_ - ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); +} + +TEST_F(XlaCompilerTest, ReturnResourceHandle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. 
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto r = ops::_Retval(scope.WithOpName("R"), var, 0); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); } xla::StatusOr> BuildTestGraph() { @@ -985,29 +1045,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1054,29 +1112,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({4, 55, 1, -3}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({27, 67, 35, 402}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. 
@@ -1171,25 +1226,73 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) - << status.error_message(); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); } +} + +class DummySideEffectingOp : public XlaOpKernel { + public: + explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken( + name(), xla::CreateToken(ctx->builder()))); + } +}; + +REGISTER_OP("DummySideEffectingOp"); + +REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp); + +TEST_F(XlaCompilerTest, TokenInputAndOutput) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef side_effecting_op; + side_effecting_op.set_name("DummySideEffectingOp"); + side_effecting_op.set_op("DummySideEffectingOp"); + AddNodeAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}, &side_effecting_op); + Status status; + graph->AddNode(side_effecting_op, &status); + TF_ASSERT_OK(status); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); - // Fix control edges for NoOp. + const std::vector empty_args; { + // The case for entry computation: we don't add token input/output. Instead, + // we use CreateToken HLO to create the entry token. 
+ XlaCompiler::CompileOptions options; + options.is_entry_computation = true; + options.add_token_input_output = false; + XlaCompiler compiler(DefaultOptions()); + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); - EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result)); - EXPECT_EQ(0, result.resource_updates.size()); + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 0); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + } + { + // The case for non-entry computation (e.g. while loop body). We add token + // input/output. + XlaCompiler::CompileOptions options; + options.is_entry_computation = false; + options.add_token_input_output = true; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken( + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5..2095a6b8099f48a867ec2c7c7d6e84d8f2426dce 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -19,6 
+19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -107,6 +106,30 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, return Status::OK(); } +Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { + VLOG(1) << "Adding retval index " << retval_index << " with resource " + << resource->name() << ":" << resource->shape().DebugString() + << " to XLA computation"; + if (retvals_.size() <= retval_index) { + retvals_.resize(retval_index + 1); + } + XlaExpression e; + e.set_resource(resource); + retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; + return Status::OK(); +} + +Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { + VLOG(1) << "Adding retval index " << retvals_.size() + << " with token to XLA computation"; + XlaExpression e; + e.set_handle(token); + // We use DT_INVALID because there is no TF DataType which corresponds to XLA + // token. XlaCompiler handles this case separately, so putting it here is OK. 
+ retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( @@ -115,7 +138,8 @@ Status XlaContext::CreateResource( const std::set& tensor_array_gradients, XlaResource** resource) { resources_.emplace_back( new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), - handle, tensor_array_size, tensor_array_gradients)); + handle, tensor_array_size, tensor_array_gradients, + /*tensor_array_multiple_writes_aggregate=*/false)); *resource = resources_.back().get(); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 3db37afdba71342cfb20af8841a40cb54709ca73..d7dbdc957f0e7969db5098b815381866cdc71ab6 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -86,6 +86,12 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::LiteralSlice& literal); + // As for Retval, but for return values that are resource handles. + Status AddResourceRetval(int retval_index, XlaResource* resource); + + // As for Retval, but for return values that are XLA tokens. + Status AppendTokenRetval(const xla::XlaOp& token); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index 23d04d43b358e858ad1ab2463322ce0ab93b23c2..bc44301d405102921de21da4bd9407032783838c 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -20,21 +20,6 @@ limitations under the License. 
namespace tensorflow { bool CpuOpFilter(KernelDef* kdef) { - // TODO(b/34339814): implement inverse erf for double types and remove this - // workaround. - if (kdef->op() == "RandomStandardNormal") { - kdef->clear_constraint(); - // Change the type constraint to permit only DTD_FLOAT. - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name("dtype"); - attr_constraint->mutable_allowed_values()->mutable_list()->add_type( - DT_FLOAT); - return true; - } - // TODO(b/26783907): The CPU backend currently does not implement sort. - if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 8efb3d55c88757b9366bdf9622287bdd0a72e295..9a34cd8c6ae2dc6d52a3cc69168df96f5322c6da 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, } /* static */ Status XlaHelpers::ReshapeLiteral( - const xla::Literal& input, gtl::ArraySlice dimensions, + const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { if (xla::ShapeUtil::IsTuple(input.shape())) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index e6522157a535fc3e4ec96cb0496b6be2e525c336..39578144caaadf293d24ea91aa874e56e27ecc01 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -18,10 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -50,7 +50,7 @@ class XlaHelpers { // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. static Status ReshapeLiteral(const xla::Literal& input, - gtl::ArraySlice shape, + absl::Span shape, xla::Literal* output); // Returns the argmax of `input` along `axis`. 
`output_type` is the type to diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 9e8f5f2a1adc4dd0dadf6c8f88c5e18dd0d1dc00..dd3498ef7aa242d3ad946cae5f60bc2c8853a342 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -67,7 +67,7 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) { +const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { return GetComputationFromTensor(GetInputTensorByName(name)); } @@ -75,7 +75,7 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } -TensorShape XlaOpKernelContext::InputShape(StringPiece name) { +TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { return GetInputTensorByName(name).shape(); } @@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const { return context_->input(index).dtype(); } +DataType XlaOpKernelContext::InputType(absl::string_view name) { + return GetInputTensorByName(name).dtype(); +} + xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; Status status = DataTypeToPrimitiveType(input_type(index), &type); @@ -100,7 +104,7 @@ Status XlaOpKernelContext::ConstantInput(int index, } static xla::StatusOr InputIndex(XlaOpKernelContext* context, - StringPiece name) { + absl::string_view name) { int start, stop; TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); if (stop != start + 1) { @@ -112,14 +116,14 @@ static xla::StatusOr InputIndex(XlaOpKernelContext* context, return start; } -Status XlaOpKernelContext::ConstantInput(StringPiece name, +Status XlaOpKernelContext::ConstantInput(absl::string_view name, xla::Literal* constant_literal) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInput(index, 
constant_literal); } Status XlaOpKernelContext::ConstantInputReshaped( - int index, gtl::ArraySlice new_dims, + int index, absl::Span new_dims, xla::Literal* constant_literal) { const Tensor& tensor = context_->input(index); TensorShape new_shape(new_dims); @@ -213,16 +217,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, - "as a compile-time constant.\nError: ", + " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } @@ -265,7 +268,7 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name, int64* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntScalar(index, out); @@ -305,7 +308,7 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name, std::vector* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntVector(index, out); @@ -344,7 +347,7 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } -Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, +Status 
XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsInt64Literal(index, out); @@ -361,7 +364,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList(StringPiece name, +Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { OpInputList inputs; @@ -376,7 +379,7 @@ Status XlaOpKernelContext::InputList(StringPiece name, } Status XlaOpKernelContext::ConstantInputList( - StringPiece name, std::vector* outputs) { + absl::string_view name, std::vector* outputs) { int start, stop; TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); @@ -429,8 +432,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, value); } -Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type, - TensorShape* shape, +Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, + DataType type, TensorShape* shape, xla::XlaOp* value) { return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, shape, value); @@ -452,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } +Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, + Tensor** output) { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + if (expected_output_dtype(index) == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in its + // value. 
+ // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + *output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, *output)); + context_->set_output(index, **output); + } else { + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); + } + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; - auto shape = builder()->GetShape(handle); - if (!shape.ok()) { - SetStatus(shape.status()); + auto shape_or = builder()->GetShape(handle); + if (!shape_or.ok()) { + SetStatus(shape_or.status()); return; } - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - TensorShape tensor_shape; - OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, - context_->allocate_output(index, tensor_shape, &output)); + allocate_output(index, shape_or.ValueOrDie(), &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. 
@@ -564,7 +587,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, handle, builder()); } -Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type, +Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); return AssignVariableTensor(GetInputTensorByName(name), type, context_, @@ -610,7 +633,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( return XlaContext::Get(context_).GetOrCreateMul(type); } -const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) { +const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; CHECK(context_->input(name, &tensor).ok()); return *tensor; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3e26ba4f015ee81d1e880f9c4ee1e1a3665af452..aa00a454968ad29495e34dc080e55b62bb0b5f7b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -71,6 +71,9 @@ class XlaOpKernelContext { // Returns the type of input `index`. DataType input_type(int index) const; + // Returns the type of input `name`. + DataType InputType(absl::string_view name); + // Returns the type of input `index` as an xla::PrimitiveType. If the type // is not representable as an XLA type, sets an error status and returns // xla::PRIMITIVE_TYPE_INVALID. @@ -79,15 +82,15 @@ class XlaOpKernelContext { // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns the shape of input `name`. - TensorShape InputShape(StringPiece name); + // Returns the shape of input with name `name`. + TensorShape InputShape(absl::string_view name); // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); // Returns input `name` as a XlaOp. 
- const xla::XlaOp& Input(StringPiece name); + const xla::XlaOp& Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -97,7 +100,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, std::vector* handles, + Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -106,26 +109,27 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); - Status ConstantInput(StringPiece name, xla::Literal* constant_literal); + Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, gtl::ArraySlice new_shape, + Status ConstantInputReshaped(int index, absl::Span new_dims, xla::Literal* constant_literal); // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); - Status ConstantInputAsIntScalar(StringPiece name, int64* out); + Status ConstantInputAsIntScalar(absl::string_view name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. 
Status ConstantInputAsIntVector(int index, std::vector* out); - Status ConstantInputAsIntVector(StringPiece name, std::vector* out); + Status ConstantInputAsIntVector(absl::string_view name, + std::vector* out); // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. @@ -133,7 +137,7 @@ class XlaOpKernelContext { // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); - Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); + Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); @@ -141,7 +145,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status ConstantInputList(StringPiece name, + Status ConstantInputList(absl::string_view name, std::vector* literals); // Outputs @@ -190,8 +194,8 @@ class XlaOpKernelContext { xla::XlaOp* value); // Reads the current value of the resouce variable referred to by input // `name`. - Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, - xla::XlaOp* value); + Status ReadVariableInput(absl::string_view name, DataType type, + TensorShape* shape, xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the @@ -199,7 +203,8 @@ class XlaOpKernelContext { // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Assigns the value `handle` to the variable referenced by input `name`. 
- Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); + Status AssignVariable(absl::string_view name, DataType type, + xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -248,7 +253,12 @@ class XlaOpKernelContext { private: // Returns the tensor of input `name`. - const Tensor& GetInputTensorByName(StringPiece name); + const Tensor& GetInputTensorByName(absl::string_view name); + + // Wraps OpKernelContext's allocate_output method while providing special + // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the + // type to allow mapping for variant to more generic types. + Status allocate_output(int index, const xla::Shape& shape, Tensor** output); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e25c7e8c9ea7590fe11564c10c9e1f49eebe36df..91d48125f1d21092db7e5f9307e44af9c16e4e2b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible compile time constant inputs."; return false; } + if (x.is_metadata_op != y.is_metadata_op) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible values for is_metadata_op."; + return false; + } return true; } @@ -105,7 +110,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; /* static */ void XlaOpRegistry::RegisterBackend( const string& compilation_device_name, - gtl::ArraySlice supported_types, BackendOpFilter op_filter) { + absl::Span supported_types, BackendOpFilter op_filter) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto result = registry.backends_.emplace(compilation_device_name, Backend()); @@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { return &it->second.front()->compile_time_constant_inputs; } +/*static*/ bool 
XlaOpRegistry::IsMetadataOp(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end() || it->second.empty()) { + return false; + } + + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // is_metadata_op, so only the first match is returned. + return it->second.front()->is_metadata_op; +} + std::vector XlaOpRegistry::BackendNames() { std::vector names; XlaOpRegistry& registry = Instance(); @@ -371,28 +390,30 @@ XlaOpRegistry& XlaOpRegistry::Instance() { return *r; } -XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { +XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = std::string(name); + registration_->name = string(name); } -XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { +XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name( + absl::string_view name) { XlaOpRegistrationBuilder registration(name); return registration; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( - gtl::ArraySlice devices) { + absl::Span devices) { registration_->has_device_whitelist = true; - for (StringPiece device : devices) { - registration_->device_whitelist.insert(std::string(device)); + for (absl::string_view device : devices) { + registration_->device_whitelist.emplace(device); } return *this; } -XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( + absl::string_view device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(std::string(device)); + registration_->device_whitelist.emplace(device); return *this; } @@ -407,17 +428,17 @@ XlaOpRegistrationBuilder& 
XlaOpRegistrationBuilder::AllowResourceTypes() { } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, DataType allowed) { + absl::string_view attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; types.insert(allowed); return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, gtl::ArraySlice allowed) { + absl::string_view attr_name, absl::Span allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -425,8 +446,13 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( - StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(std::string(input_name)); + absl::string_view input_name) { + registration_->compile_time_constant_inputs.emplace(input_name); + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() { + registration_->is_metadata_op = true; return *this; } @@ -452,10 +478,10 @@ XlaOpRegistrar::XlaOpRegistrar( } XlaBackendRegistrar::XlaBackendRegistrar( - StringPiece name, gtl::ArraySlice types, + absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(std::string(name), types, op_filter); + registry.RegisterBackend(string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6ce0e2580b1a9b75fe72fba931d80c96b3870fce..4b2c2bacd647b3e6fe500a942b116772550195ce 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ 
-47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { + {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. @@ -94,7 +95,7 @@ class XlaOpRegistry { // the device; it may optionally modify the KernelDef. typedef bool (*BackendOpFilter)(KernelDef* kdef); static void RegisterBackend(const string& compilation_device_name, - gtl::ArraySlice supported_types, + absl::Span supported_types, BackendOpFilter op_filter); // Returns the names of the registered backends. @@ -136,6 +137,10 @@ class XlaOpRegistry { static const std::unordered_set* CompileTimeConstantInputs( const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes + // of its operands and not their values. 
+ static bool IsMetadataOp(const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -192,6 +197,10 @@ class XlaOpRegistry { // Names of arguments that must be compile-time constants. std::unordered_set compile_time_constant_inputs; + // True if this is a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + bool is_metadata_op = false; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -232,19 +241,19 @@ class XlaOpRegistry { class XlaOpRegistrationBuilder { public: // Starts an operator registration chain. - static XlaOpRegistrationBuilder Name(StringPiece name); + static XlaOpRegistrationBuilder Name(absl::string_view name); // Specifies a whitelist of devices on which the operator may run. - XlaOpRegistrationBuilder& Device(StringPiece devices); - XlaOpRegistrationBuilder& Device(gtl::ArraySlice devices); + XlaOpRegistrationBuilder& Device(absl::string_view devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); // Specifies a type constraint for a type variable attribute. Each constraint // specifies the set of types that the type variable may assume. - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, DataType allowed); - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, - gtl::ArraySlice allowed); + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, + absl::Span allowed); // Specifies that a dummy copy of this operator should not be registered on // XLA_* devices, but may be used during compilation. @@ -254,13 +263,17 @@ class XlaOpRegistrationBuilder { XlaOpRegistrationBuilder& AllowResourceTypes(); // Mark 'input_name' as an argument whose value must be known at compile-time. 
- XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); + + // Mark this op as a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + XlaOpRegistrationBuilder& IsMetadataOp(); std::unique_ptr Build( XlaOpRegistry::Factory factory); private: - XlaOpRegistrationBuilder(StringPiece name); + XlaOpRegistrationBuilder(absl::string_view name); std::unique_ptr registration_; }; @@ -288,7 +301,7 @@ class XlaOpRegistrar { class XlaBackendRegistrar { public: - XlaBackendRegistrar(StringPiece name, gtl::ArraySlice types, + XlaBackendRegistrar(absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter = nullptr); }; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 7928fa034725206a752cbfe086d01f15cd235df9..63b09c8f02a60e91576544d13227d29f56d3e88c 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -29,7 +29,8 @@ namespace tensorflow { XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, - const std::set& tensor_array_gradients) + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate) : kind_(kind), arg_num_(arg_num), name_(std::move(name)), @@ -37,14 +38,17 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, shape_(std::move(shape)), value_(initial_value), initial_value_(initial_value), - tensor_array_size_(tensor_array_size) { + tensor_array_size_(tensor_array_size), + tensor_array_multiple_writes_aggregate_( + tensor_array_multiple_writes_aggregate) { CHECK(kind_ != kInvalid); for (const string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, - 
/*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, - xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, + xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}, + /*tensor_array_multiple_writes_aggregate=*/true)); } } @@ -135,9 +139,10 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, gradient_value, tensor_array_size_, - /*tensor_array_gradients=*/{})); + /*tensor_array_gradients=*/{}, + /*tensor_array_multiple_writes_aggregate=*/true)); } *gradient_out = gradient.get(); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 2438490be13809b9f3571a362900b44cb838e76b..aa9ce1b171f11ea0de4db0123098729c1c97f93a 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -39,7 +39,8 @@ class XlaResource { XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, - const std::set& tensor_array_gradients); + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate); XlaResource(const XlaResource&) = delete; XlaResource(XlaResource&&) = delete; @@ -113,6 +114,8 @@ class XlaResource { const xla::XlaOp& pack, xla::XlaBuilder* builder); // TensorArray and Stack specific fields + // TODO(phawkins): refactor this code to use subclasses, rather than putting + // kind-specific fields in XlaResource. // 'tensor_array_size' stores the expected size of the TensorArray or Stack. 
// We need to store this since sometimes TensorArrays must be initialized @@ -121,6 +124,10 @@ class XlaResource { int64 tensor_array_size() const { return tensor_array_size_; } void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + bool tensor_array_multiple_writes_aggregate() const { + return tensor_array_multiple_writes_aggregate_; + } + // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes // to an XlaResource containing the gradient TensorArrays. We store a pointer // here since there should only be one gradient TensorArray per 'source' @@ -143,6 +150,7 @@ class XlaResource { xla::XlaOp initial_value_; int64 tensor_array_size_ = -1; + bool tensor_array_multiple_writes_aggregate_ = false; std::map> tensor_array_gradients_; }; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 26bd1ac4f7e316208bcf0d085128c2242787d3df..cc7390c6e60375b4c31c38f9f7dee25730f8f51e 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -175,6 +175,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -191,6 +193,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", ], ) @@ -242,9 +245,11 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -305,6 +310,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -347,6 +354,7 @@ cc_library( 
"//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -366,6 +374,7 @@ cc_library( ":util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -378,6 +387,7 @@ cc_library( ":util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -400,6 +410,7 @@ cc_library( ":types", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -468,6 +479,7 @@ cc_library( ":types", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -479,6 +491,7 @@ tf_cc_test( ":test", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/types:span", ], ) @@ -506,6 +519,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", ], ) @@ -573,6 +587,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -603,6 +618,7 @@ cc_library( "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -644,6 +660,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -668,6 +685,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -698,8 +716,8 @@ cc_library( ":array2d", ":shape_util", ":xla_data_proto", - "//tensorflow/core:lib", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/array.h 
b/tensorflow/compiler/xla/array.h index c8e483712efb48e49135f8775ef079497f68776f..58cc1575858201b4508d7340cb47e59c4f4c5783 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -29,10 +29,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -97,12 +97,11 @@ class Array { using value_type = T; // Creates a new array with the specified dimensions. - explicit Array(tensorflow::gtl::ArraySlice sizes) - : Array(sizes, T()) {} + explicit Array(absl::Span sizes) : Array(sizes, T()) {} // Creates a new array with the specified dimensions and specified value for // every cell. - Array(tensorflow::gtl::ArraySlice sizes, T value) + Array(absl::Span sizes, T value) : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { Fill(value); } @@ -301,7 +300,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. - void Each(std::function, T*)> f) { + void Each(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, &values_[i]); @@ -309,8 +308,7 @@ class Array { } // Invokes a callback with the (indices, value) for each cell in the array. - void Each( - std::function, T)> f) const { + void Each(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, values_[i]); @@ -320,8 +318,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. 
If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T*)> f) { + Status EachStatus(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, &values_[i]); @@ -335,8 +332,7 @@ class Array { // Invokes a callback with the (indices, value) for each cell in the array. // If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T)> f) const { + Status EachStatus(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, values_[i]); @@ -377,13 +373,13 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - const T& operator()(tensorflow::gtl::ArraySlice indexes) const { + const T& operator()(absl::Span indexes) const { return values_[calculate_index(indexes)]; } // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - T& operator()(tensorflow::gtl::ArraySlice indexes) { + T& operator()(absl::Span indexes) { return values_[calculate_index(indexes)]; } @@ -438,8 +434,8 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } // Performs the equivalent of a slice operation on this array. - Array Slice(tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits) const { + Array Slice(absl::Span starts, + absl::Span limits) const { CHECK_EQ(starts.size(), num_dimensions()); CHECK_EQ(limits.size(), num_dimensions()); @@ -464,7 +460,7 @@ class Array { // Performs the equivalent of a DynamicUpdateSlice in-place on this array. 
void UpdateSlice(const Array& from, - tensorflow::gtl::ArraySlice start_indices) { + absl::Span start_indices) { CHECK_EQ(from.num_dimensions(), num_dimensions()); std::vector limit_indices; std::transform(start_indices.begin(), start_indices.end(), @@ -484,7 +480,7 @@ class Array { // Performs an in-place reshape, modifying the dimensions but not the // underlying data. - void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + void Reshape(absl::Span new_dimensions) { int64 old_num_elements = num_elements(); sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); CHECK_EQ(num_elements(), old_num_elements); diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index 14e7bf1814120beb0247c4b130d72201785e58a7..e23d317baf9aca7b3705a93d6be952fb9a17762b 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -27,11 +27,10 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 927733ea1eab43feff643c35535cc6d9ea59ba5a..918872a7a03a022c72d22dfb8f0da9e9d3820e41 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { @@ -27,8 +27,7 @@ namespace { // Given an Array4D and a 4-tuple index, computes the linear index into the // array idx represents. template -int64 Array4DLinearIndex(const Array4D& arr, - tensorflow::gtl::ArraySlice idx) { +int64 Array4DLinearIndex(const Array4D& arr, absl::Span idx) { EXPECT_EQ(4, idx.size()); return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() + idx[0] * arr.n2() * arr.n3() * arr.n4()); @@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) { EXPECT_EQ(fullof7.n3(), 4); EXPECT_EQ(fullof7.n4(), 5); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); } TEST(Array4dTest, ContainerCtor) { @@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) { EXPECT_EQ(arr.n3(), 4); EXPECT_EQ(arr.n4(), 5); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, int* cell) { + arr.Each([&arr](absl::Span idx, int* cell) { EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx)); }); } @@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) { TEST(Array4dTest, Fill) { Array4D fullof7(2, 3, 4, 5, 7); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); fullof7.Fill(11); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 11); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 11); }); } TEST(Array4dTest, FillWithMultiples) { Array4D arr(2, 3, 4, 5); arr.FillWithMultiples(2.0f); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, float* cell) { + arr.Each([&arr](absl::Span idx, float* cell) { EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx)); }); } diff --git a/tensorflow/compiler/xla/array_test.cc 
b/tensorflow/compiler/xla/array_test.cc index e8356c9832d34135f5ffb1a5c7a9d6db6db3a051..2d0ac98bd4ee27004295c4189cb190bb2c9739c9 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -163,7 +163,7 @@ TEST(ArrayTest, Each) { arr.FillWithMultiples(1); int64 each_count = 0, each_sum = 0; - arr.Each([&](tensorflow::gtl::ArraySlice idx, int cell) { + arr.Each([&](absl::Span idx, int cell) { int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2]; EXPECT_EQ(lin_idx, cell); each_count++; diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 9ad8ee20141c46c02e7a5b50c62b884f1cda79c8..dc097f3696e22d75d7dc72ec4877a9c8b5dda059 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -78,6 +79,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -92,6 +94,7 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], ) @@ -117,9 +120,9 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", "@llvm//:support", ], ) @@ -217,8 +220,11 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 1fdf8f6260d3f00db43647a4d4de2842d69bf833..5dde5b432f136c16d4e3795569499ee5de709763 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {} Client::~Client() = default; -StatusOr> Client::Transfer( - const GlobalData& data, const Shape* shape_with_layout) { +StatusOr Client::Transfer(const GlobalData& data, + const Shape* shape_with_layout) { TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { @@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, return Status::OK(); } -StatusOr> Client::TransferFromOutfeed( +StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id, const DeviceHandle* device_handle) { TransferFromOutfeedRequest request; @@ -162,9 +162,8 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, +StatusOr Client::ExecuteAndTransfer( + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { TF_ASSIGN_OR_RETURN( @@ -178,8 +177,8 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr> Client::ComputeConstant( - const XlaComputation& computation, const Layout* output_layout) const { +StatusOr Client::ComputeConstant(const XlaComputation& computation, + const Layout* output_layout) const { ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { @@ -212,8 +211,7 @@ StatusOr 
Client::LoadSnapshot(const HloSnapshot& module) { } StatusOr> Client::Execute( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { ExecuteGraphRequest request; @@ -252,7 +250,7 @@ StatusOr> Client::Execute( } StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { ExecuteGraphParallelRequest request; for (const XlaComputationInstance& computation : computations) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index be50cebfcc0e3c19002635dbd280b14048aa0c93..6f4d33c469f1f885cfeef546e3981dc3417ef71f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -53,7 +53,7 @@ class Client { // will be filled with profile data from the execution. StatusOr> Execute( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -82,7 +82,7 @@ class Client { // from each computation. // StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Requests device_count device handles available on the target. 
The returned // device handles are used to specify the devices to execute the computations @@ -96,8 +96,8 @@ class Client { // // If shape_with_layout is not nullptr, it points to a shape whose layout will // be the layout of the returned literal. - StatusOr> Transfer( - const GlobalData& data, const Shape* shape_with_layout = nullptr); + StatusOr Transfer(const GlobalData& data, + const Shape* shape_with_layout = nullptr); // Transfer the given literal to the server. This allocates memory on the // device and copies the literal's contents over. Returns a global data handle @@ -122,7 +122,7 @@ class Client { // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - StatusOr> TransferFromOutfeed( + StatusOr TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); @@ -132,9 +132,9 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -153,7 +153,7 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. 
- StatusOr> ComputeConstant( + StatusOr ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 040344c9a65de122a21831b0eb79504ab4401772..a6c58cb17571b63cd0f45d0d95376a02bc4a72e2 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -23,7 +23,7 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector service_instances; diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index d0c83cbfccb99755f8f5b7fa2e179f25fb73d3d1..9e3ed23734941d98d622c38028cd44d48d3e620a 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -52,7 +52,7 @@ class CompileOnlyClient : public Client { // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata = nullptr); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 5a73408db5fd0c75fe9bc588f4800b4ac965d009..0f1745366b7c33e573aff2e66d85431b01488c49 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -59,10 +59,10 @@ string ExecutableBuildOptions::ToString() const { if (generate_hlo_graph_.has_value()) { generate_hlo_graph = generate_hlo_graph_.value(); } - return tensorflow::strings::Printf( + return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); + device_ordinal_, result_layout, generate_hlo_graph); } ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 888d2f28ebb2cfc73a58ba07d58d10405fb76832..93334db88bc24f2ffbf3c7a57ee45ef238286739 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -86,7 +86,7 @@ class ExecutableBuildOptions { void add_disabled_hlo_pass(absl::string_view pass_name) { disabled_hlo_passes_.push_back(std::string(pass_name)); } - const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + const absl::Span disabled_hlo_passes() const { return disabled_hlo_passes_; } diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 8736f18dcfa678f35ba9c749d373d2d4ad6a9bd6..a18c94c4e695a6cdcb9dcc60b64b617cecd276d8 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -113,7 +113,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/client/lib/constants.cc 
b/tensorflow/compiler/xla/client/lib/constants.cc index 031d62e4ffef188082303a28866bbc72a154e9b1..1ada7b4a964ccf7ca400b937abbe425bef083468 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -56,7 +56,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { std::numeric_limits::epsilon()); default: return builder->ReportError(InvalidArgument( - "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178..81624614c1e3599dfe116eb61d9e2edcd5230684 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -37,13 +37,13 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { primitive_util::IsComplexType(type))) { return builder->ReportError(InvalidArgument( "Invalid cast from floating point type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } if (std::is_same::value && !primitive_util::IsComplexType(type)) { return builder->ReportError(InvalidArgument( "Invalid cast from complex type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } switch (type) { case F16: @@ -71,7 +71,7 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { default: return builder->ReportError( InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h index c18087ce6b6addde62523a2d556e5f8146aa5dd1..0ad01728e6e828240b9ac4b948777e5d970d09e0 100644 --- 
a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h +++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index e569610b85578769750216d18151e635d475db37..d3d7edb42a38595bbf9fdb36e0dd946ae5df51f9 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -69,8 +69,7 @@ std::array kErfUCoefficient = { // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients) { +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { XlaOp poly = ScalarLike(x, 0.0); for (float c : coefficients) { poly = poly * x + ScalarLike(x, c); diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 13db2325569cf2e25e3ff1200adf4b2544dc2f73..a6cafd42077367bf23ffa1f45eab31c01dc31b16 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand); // Evaluates a polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients); +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); // Computes an approximation of the error function complement (1 - erf(x)). 
XlaOp Erfc(XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc index 1c91237ae1574f92cda78c9bddc6f4ac1d68f47c..377654220b5df4487e9e194361473d54ff46a54e 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -16,61 +16,13 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { -namespace { - -template -XlaOp MakeIota(XlaBuilder* builder, int64 size) { - std::vector values(size); - for (int64 i = 0; i < size; ++i) { - values[i] = static_cast(i); - } - return ConstantR1(builder, values); -} - -} // namespace - -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { - switch (type) { - case S8: - return MakeIota(builder, size); - case S16: - return MakeIota(builder, size); - case S32: - return MakeIota(builder, size); - case S64: - return MakeIota(builder, size); - case U8: - return MakeIota(builder, size); - case U16: - return MakeIota(builder, size); - case U32: - return MakeIota(builder, size); - case U64: - return MakeIota(builder, size); - case BF16: - return MakeIota(builder, size); - case F16: - return MakeIota(builder, size); - case F32: - return MakeIota(builder, size); - case F64: - return MakeIota(builder, size); - case C64: - return MakeIota(builder, size); - default: - return builder->ReportError( - InvalidArgument("Unimplemented type for Iota: %s.", - PrimitiveType_Name(type).c_str())); - } -} - XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n) { auto a = Iota(builder, type, m); @@ -87,8 +39,8 @@ XlaOp GetMatrixDiagonal(XlaOp x) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = 
shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims( - AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); @@ -114,8 +66,8 @@ XlaOp Triangle(XlaOp x, bool lower) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims( - AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); xla::XlaOp indicator; diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc index 8a96ec68d2dca8485215258b1f6731b934e6f2a8..7d6aedd49462bd4f075f90d0b0f85c40f1191aa1 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase { void TestMatrixDiagonal(); }; -// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to -// xla::Iota. This test is already implemented for xla::IotaGen in -// xla/tests/iota_test.cc. 
-XLA_TEST_F(NumericTest, Iota) { - XlaBuilder builder(TestName()); - Iota(&builder, S32, 10); - - ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); -} - XLA_TEST_F(NumericTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc index 3ae9ae36f654a8f5026ac3a37976dc97aca357ac..1979c867a4c3be438f8b997c566799fe84b43053 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.cc +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -26,11 +26,9 @@ namespace { // element of an image by the count of elements that contributed to that // element during pooling. XlaOp AvgPoolDivideByCountWithGeneralPadding( - XlaOp sums, PrimitiveType dtype, - tensorflow::gtl::ArraySlice input_shape, - tensorflow::gtl::ArraySlice> spatial_padding, - tensorflow::gtl::ArraySlice ksize, - tensorflow::gtl::ArraySlice stride, + XlaOp sums, PrimitiveType dtype, absl::Span input_shape, + absl::Span> spatial_padding, + absl::Span ksize, absl::Span stride, const TensorFormat& data_format) { // The padding shouldn't be included in the counts. We use another // ReduceWindow to find the right counts. @@ -73,8 +71,8 @@ XlaOp AvgPoolDivideByCountWithGeneralPadding( // Sums all elements in the window specified by 'kernel_size' and 'stride'. XlaOp ComputeSums(XlaOp operand, XlaOp init_value, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, + absl::Span kernel_size, + absl::Span stride, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { @@ -89,8 +87,8 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value, // Creates a padding configuration out of spatial padding values. 
PaddingConfig MakeSpatialPaddingConfig( - tensorflow::gtl::ArraySlice> spatial_padding, - int num_spatial_dims, tensorflow::gtl::ArraySlice stride, + absl::Span> spatial_padding, + int num_spatial_dims, absl::Span stride, const TensorFormat& data_format) { PaddingConfig padding_config; for (int i = 0; i < 2 + num_spatial_dims; ++i) { @@ -107,13 +105,12 @@ PaddingConfig MakeSpatialPaddingConfig( return padding_config; } -XlaOp AvgPoolDivideByCount( - XlaOp pooled, tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - PrimitiveType dtype, const TensorFormat& data_format, - bool counts_include_padding) { +XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span input_size, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + PrimitiveType dtype, const TensorFormat& data_format, + bool counts_include_padding) { if (counts_include_padding) { // If counts include padding, all windows have the same number of elements // contributing to each average. 
Divide by the window size everywhere to get @@ -133,8 +130,8 @@ XlaOp AvgPoolDivideByCount( } // namespace -XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { @@ -147,9 +144,9 @@ XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, }); } -XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> padding, +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, const TensorFormat& data_format, const bool counts_include_padding) { XlaBuilder* b = operand.builder(); @@ -173,9 +170,8 @@ XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, } std::vector> MakeSpatialPadding( - tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format) { const int num_spatial_dims = kernel_size.size() - 2; std::vector input_spatial_dimensions; @@ -193,12 +189,12 @@ std::vector> MakeSpatialPadding( stride_spatial_dimensions, padding); } -XlaOp AvgPoolGrad( - XlaOp out_backprop, tensorflow::gtl::ArraySlice gradients_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> spatial_padding, - const TensorFormat& data_format, const bool counts_include_padding) { +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding) { XlaBuilder* b = 
out_backprop.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { const int num_dims = kernel_size.size(); diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h index 291c711a005eb7e7e544bb792eb09422491d5d69..5c0054857d072dc7f36e259a29b9b24fd70796ac 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.h +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -25,7 +25,7 @@ namespace xla { class TensorFormat { public: TensorFormat(int batch_dimension, int feature_dimension, - tensorflow::gtl::ArraySlice spatial_dimensions) + absl::Span spatial_dimensions) : batch_dimension_(batch_dimension), feature_dimension_(feature_dimension), spatial_dimensions_(spatial_dimensions.begin(), @@ -49,32 +49,31 @@ class TensorFormat { }; // Computes the max pool of 'operand'. -XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format); // Computes the average pool of 'operand'. -XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> padding, +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, const TensorFormat& data_format, const bool counts_include_padding); // Returns the list of low and high padding elements in each spatial dimension // for the given 'padding' specification. std::vector> MakeSpatialPadding( - tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format); // Computes the average pool gradient. 
-XlaOp AvgPoolGrad( - XlaOp out_backprop, tensorflow::gtl::ArraySlice gradients_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> spatial_padding, - const TensorFormat& data_format, const bool counts_include_padding); +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc index 18900479189c3afd131969687a973ea6061ffd9f..30adb9b1ad7fa03b40ce3802a2172680b60a9ad7 100644 --- a/tensorflow/compiler/xla/client/lib/pooling_test.cc +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -32,8 +32,8 @@ TensorFormat MakeNCHWFormat(int num_spatial_dims) { } std::vector> MakeGeneralPadding( - XlaOp input, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + XlaOp input, absl::Span kernel_size, + absl::Span stride, Padding padding, const xla::TensorFormat& data_format) { XlaBuilder* b = input.builder(); Shape operand_shape = b->GetShape(input).ValueOrDie(); @@ -46,7 +46,7 @@ std::vector> MakeGeneralPadding( // Add singleton batch and feature dimensions to spatial dimensions, according // to 'data_format' specification. 
std::vector ExpandWithBatchAndFeatureDimensions( - tensorflow::gtl::ArraySlice spatial_dim_sizes, + absl::Span spatial_dim_sizes, const xla::TensorFormat& data_format) { const int num_spatial_dims = spatial_dim_sizes.size(); std::vector tensor_sizes(num_spatial_dims + 2, 1); diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index a904be259a3870a679b2c4699ec01e2a11b1ce46..0475fd9c94f6e390b5169cfe2cbba8eae28ddc18 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -29,7 +29,7 @@ XlaOp TopK(XlaOp input, int64 k) { auto input_dims = input_shape.dimensions(); std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); - XlaOp sort_result = Sort(Neg(input), broadcast_s32); + XlaOp sort_result = Sort(Neg(input), {broadcast_s32}); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 6861521acc0db1d640666a6793b898a183ab6a17..a44681f586278bf03f3fb2b8c812936cbf3ad47b 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -76,7 +76,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { if (DataSizeOfShape(shape) < (1LL << 20)) { - StatusOr> literal_status = MakeFakeLiteral(shape); + StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via // an on-device computation. 
@@ -84,7 +84,7 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, tensorflow::error::UNIMPLEMENTED); return MakeFakeDataViaDeviceOrDie(shape, client); } - return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie(); + return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. @@ -93,17 +93,15 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { - CHECK(computation.proto().has_program_shape()) + CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; - auto program_shape = computation.proto().program_shape(); - - // Create and run a program which produces a tuple with one element per - // parameter, then return the tuple's constituent buffers. - std::vector param_shapes(program_shape.parameters().begin(), - program_shape.parameters().end()); - auto fake_input_tuple = - MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); - return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); + auto program_shape = computation.proto().host_program_shape(); + + std::vector> results; + for (const Shape& shape : program_shape.parameters()) { + results.push_back(MakeFakeDataOrDie(shape, client)); + } + return results; } } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1cd3e9b22f9cf3383cfcbc19c79acba0e5938190..f96b6c9c261a9686fb647e3da0dcc933cd1f70df 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, } Status LocalExecutable::ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& 
computation_layout = executable_->module_config().entry_computation_layout(); @@ -59,7 +59,7 @@ Status LocalExecutable::ValidateExecutionOptions( // Check argument number, shapes, and layouts. if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", + "invalid number of arguments for computation: expected %d, got %u", computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); + ShapeUtil::HumanString( + computation_layout.parameter_layout(i).shape()), + ShapeUtil::HumanString(arguments[i]->on_host_shape())); } } @@ -88,8 +88,7 @@ Status LocalExecutable::ValidateExecutionOptions( if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", - stream_platform->Name().c_str(), - backend_->platform()->Name().c_str()); + stream_platform->Name(), backend_->platform()->Name()); } // Cannot specify device_ordinal with a stream. 
The stream determines these @@ -120,10 +119,10 @@ Status LocalExecutable::ValidateExecutionOptions( return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal()).c_str(), - build_executor->GetDeviceDescription().name().c_str(), - backend_->device_name(run_device_ordinal).c_str(), - run_executor->GetDeviceDescription().name().c_str()); + backend_->device_name(build_device_ordinal()), + build_executor->GetDeviceDescription().name(), + backend_->device_name(run_device_ordinal), + run_executor->GetDeviceDescription().name()); } if (!run_options.allocator()) { @@ -133,15 +132,15 @@ Status LocalExecutable::ValidateExecutionOptions( if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - run_options.allocator()->platform()->Name().c_str(), - backend.platform()->Name().c_str()); + run_options.allocator()->platform()->Name(), + backend.platform()->Name()); } return Status::OK(); } StatusOr LocalExecutable::Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); @@ -178,7 +177,7 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments) { + const absl::Span arguments) { executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); @@ -192,13 +191,12 @@ StatusOr LocalExecutable::ExecuteAndDump( } Status LocalExecutable::RecordArguments( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - 
TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*argument)); - *hlo_snapshot->add_arguments() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument)); + *hlo_snapshot->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -206,13 +204,12 @@ Status LocalExecutable::RecordArguments( Status LocalExecutable::RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_result(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*result)); - *hlo_snapshot->mutable_result() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result)); + *hlo_snapshot->mutable_result() = literal.ToProto(); return Status::OK(); } -StatusOr> LocalExecutable::LiteralFromShapedBuffer( +StatusOr LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, backend_->BorrowStream(shaped_buffer.device_ordinal())); @@ -246,7 +243,7 @@ Backend* LocalClient::mutable_backend() { StatusOr> LocalClient::Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options) { ExecutableBuildOptions updated_options = options; if (options.device_ordinal() == -1) { @@ -278,7 +275,7 @@ StatusOr LocalClient::LiteralToShapedBuffer( return std::move(scoped_buffer); } -StatusOr> LocalClient::ShapedBufferToLiteral( +StatusOr LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( shaped_buffer.device_ordinal())); @@ -299,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal, literal); } -StatusOr> LocalClient::TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal) { +StatusOr LocalClient::TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal) { 
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, shape, literal.get())); + executor, shape, &literal)); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ae23809261757c637ab4aec036750c371ac60cdc..feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -40,7 +40,7 @@ class LocalExecutable { // Run the compiled computation with the given arguments and options and // return the result. StatusOr Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options); // Return the options used to build the executable. @@ -63,7 +63,7 @@ class LocalExecutable { // The given ExecutableRunOptions override any values from legacy_flags // (TF_XLA_FLAGS environment variable). 
Status ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to @@ -73,20 +73,18 @@ class LocalExecutable { // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments); + const absl::Span arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - HloSnapshot* hlo_snapshot); + Status RecordArguments(const absl::Span arguments, + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. - StatusOr> LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer); + StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by @@ -120,7 +118,7 @@ class LocalClient : public Client { // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options); // Copy the literal data to the device with the given ordinal and return as a @@ -133,8 +131,7 @@ class LocalClient : public Client { // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. 
- StatusOr> ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. @@ -152,8 +149,8 @@ class LocalClient : public Client { // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferFromOutfeed. - StatusOr> TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal); + StatusOr TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f..992b13139c480900e7b983825be61ce88f14e11b 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -23,16 +23,15 @@ limitations under the License. 
namespace xla { -Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides) { +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides) { bool ok = input_dimensions.size() == window_dimensions.size() && input_dimensions.size() == window_strides.size(); if (!ok) { return InvalidArgument( - "Want input dimensions size %zu = window dimensions size %zu = window " - "strides size %zu", + "Want input dimensions size %u = window dimensions size %u = window " + "strides size %u", input_dimensions.size(), window_dimensions.size(), window_strides.size()); } @@ -40,9 +39,9 @@ Status ValidatePaddingValues( } std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, window_strides)); std::vector> low_high_padding; diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index e23b0b3a90a091bf80973525810793c3eda4a036..5c009bd49e48b158550a32e64b0d63e2840dd1a9 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -41,10 +41,9 @@ enum class Padding { // Validates that the slices are acceptable for determining padding -- this can // be used to check the preconditions of MakePadding below to produce an error // message that can be returned to the user. 
-Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides); +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides); // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. @@ -58,9 +57,9 @@ Status ValidatePaddingValues( // window_dimensions, and strides must match, which is equal to the number // of elements in the result vector. std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 9f902d7298cb1cc1da998580b01656c552ea8cbb..7d081b27222bd31ddbe7c64b4dea8a4d5a371acb 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -72,7 +72,7 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { if (!ShapeUtil::ElementIsIntegral(shape)) { return InvalidArgument( "Argument to >> operator does not have an integral type (%s).", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (ShapeUtil::ElementIsSigned(shape)) { return ShiftRightArithmetic(x, y); @@ -90,7 +90,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { } StatusOr> XlaBuilder::GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const { + absl::Span operands) const { std::vector operand_shapes; for (const XlaOp& operand : operands) { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); @@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn( StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto, + LookUpInstructionByHandle(root_id)); ProgramShape program_shape; - *program_shape.mutable_result() = instructions_[root_id].shape(); + *program_shape.mutable_result() = root_proto->shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. 
@@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, return; } - CHECK(op_handle < instructions_.size() && op_handle >= 0); - - const HloInstructionProto& instr = instructions_[op_handle]; + const HloInstructionProto& instr = + *(LookUpInstructionByHandle(op_handle).ValueOrDie()); const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); switch (opcode) { default: @@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. + case HloOpcode::kScatter: + // TODO(b/32495713): We aren't checking the embedded computation in + // Scatter. case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kParameter: @@ -275,7 +278,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); - *module->mutable_program_shape() = entry.program_shape(); + *module->mutable_host_program_shape() = entry.program_shape(); for (auto& e : embedded_) { module->add_computations()->Swap(&e.second); } @@ -283,6 +286,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { // Clear data held by this builder. 
this->instructions_.clear(); + this->handle_to_index_.clear(); this->embedded_.clear(); this->parameter_numbers_.clear(); @@ -291,7 +295,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; @@ -352,9 +356,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { }); } -XlaOp XlaBuilder::BinaryOp( - HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -448,12 +451,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } @@ -466,8 +469,21 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.add_dimensions(iota_dimension); + return AddInstruction(std::move(instr), HloOpcode::kIota); + }); +} + +XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) { + return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); +} + XlaOp XlaBuilder::Call(const 
XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -492,7 +508,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { - return InvalidArgument("parameter %lld already registered", + return InvalidArgument("parameter %d already registered", parameter_number); } instr.set_parameter_number(parameter_number); @@ -502,8 +518,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, }); } -XlaOp XlaBuilder::Broadcast( - const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp XlaBuilder::Broadcast(const XlaOp& operand, + absl::Span broadcast_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -528,7 +544,7 @@ XlaOp XlaBuilder::Broadcast( XlaOp XlaBuilder::BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { + const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { return InDimBroadcast(shape, operand, broadcast_dimensions); }); @@ -543,9 +559,9 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { } XlaOp XlaBuilder::Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -580,7 +596,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, } XlaOp XlaBuilder::DynamicSlice(const 
XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -618,7 +634,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, }); } -XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -658,8 +674,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span dimensions, + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, @@ -673,7 +689,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); @@ -683,7 +699,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus @@ -693,8 +709,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, // Out-of-order collapse is not supported. // Checks that the collapsed dimensions are in order and consecutive. 
- for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { + for (absl::Span::size_type i = 1; i < dimensions.size(); ++i) { if (dimensions[i] - 1 != dimensions[i - 1]) { return InvalidArgument( "Collapsed dimensions are not in consecutive order."); @@ -727,7 +742,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } @@ -745,7 +760,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, }); } -XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { +XlaOp XlaBuilder::Tuple(absl::Span elements) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -766,7 +781,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { if (!ShapeUtil::IsTuple(tuple_shape)) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(tuple_shape).c_str()); + ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = ShapeUtil::GetTupleElementShape(tuple_shape, index); @@ -779,37 +794,37 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice 
broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -817,14 +832,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 
0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config); }); } -XlaOp XlaBuilder::DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -833,8 +847,8 @@ XlaOp XlaBuilder::DotGeneral( ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); @@ -847,16 +861,14 @@ Status XlaBuilder::VerifyConvolution( return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_dims = ShapeUtil::Rank(lhs_shape); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. 
" "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_spatial_dims = num_dims - 2; @@ -870,7 +882,7 @@ Status XlaBuilder::VerifyConvolution( } for (int i = 0; i < numbers.size(); ++i) { if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + return InvalidArgument("Convolution %s[%d] is out of bounds: %d", field_name, i, numbers.Get(i)); } } @@ -888,32 +900,28 @@ Status XlaBuilder::VerifyConvolution( } XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + absl::Span window_strides, Padding padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 
feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -941,31 +949,26 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); }); } XlaOp XlaBuilder::ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; 
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -986,14 +989,14 @@ XlaOp XlaBuilder::ConvGeneralDilated( TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, instr.window(), - dimension_numbers, feature_group_count)); + lhs_shape, rhs_shape, feature_group_count, + instr.window(), dimension_numbers)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kConvolution, @@ -1002,11 +1005,11 @@ XlaOp XlaBuilder::ConvGeneralDilated( } StatusOr XlaBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const { const auto verify_size = [&](const size_t x, const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return Status::OK(); @@ -1016,8 +1019,7 @@ StatusOr XlaBuilder::MakeWindow( "Window has different number of window dimensions than of ", x_name, "\nNumber of window dimensions: ", window_dimensions.size(), - "\nNumber of ", x_name, ": ", x, "\n") - .c_str()); + "\nNumber of ", x_name, ": ", x, "\n")); } }; TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); @@ -1057,7 +1059,7 @@ StatusOr XlaBuilder::MakeWindow( } XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto 
instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1193,8 +1195,8 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; @@ -1246,8 +1248,8 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; @@ -1266,7 +1268,7 @@ XlaOp XlaBuilder::CreateToken() { }); } -XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { +XlaOp XlaBuilder::AfterAll(absl::Span tokens) { return ReportErrorOrReturn([&]() -> StatusOr { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); @@ -1277,26 +1279,52 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { }); } -XlaOp XlaBuilder::CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { +XlaOp XlaBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( 
"Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", - call_target_name.c_str()); + call_target_name); } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); + instr.set_custom_call_opaque(opaque); + if (operand_shapes_with_layout.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument( + "Result shape must have layout for custom call with constrained " + "layout."); + } + if (operands.size() != operand_shapes_with_layout->size()) { + return InvalidArgument( + "Must specify a shape with layout for each operand for custom call " + "with constrained layout; given %d shapes, expected %d", + operand_shapes_with_layout->size(), operands.size()); + } + instr.set_constrain_layout(true); + int64 operand_num = 0; + for (const Shape& operand_shape : *operand_shapes_with_layout) { + if (!LayoutUtil::HasLayout(operand_shape)) { + return InvalidArgument( + "No layout specified for operand %d for custom call with " + "constrained layout.", + operand_num); + } + *instr.add_operand_shapes_with_layout() = operand_shape; + ++operand_num; + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } -XlaOp XlaBuilder::Complex( - const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); } @@ -1305,42 +1333,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) { } XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return 
BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } @@ -1348,22 +1376,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) { return UnaryOp(HloOpcode::kNot, operand); } -XlaOp XlaBuilder::ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return 
BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightLogical( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, broadcast_dimensions); } @@ -1372,9 +1399,8 @@ XlaOp XlaBuilder::Abs(const XlaOp& operand) { return UnaryOp(HloOpcode::kAbs, operand); } -XlaOp XlaBuilder::Atan2( - const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); } @@ -1439,7 +1465,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { } XlaOp XlaBuilder::Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { + absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1454,7 +1480,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, } XlaOp XlaBuilder::Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1468,18 +1494,17 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, +XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); operand_shape_ptrs.push_back(&keys_shape); - Shape values_shape; - if (values.has_value()) { - TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); - operand_shape_ptrs.push_back(&values_shape); - } + TF_ASSIGN_OR_RETURN(std::vector 
values_shapes, + GetOperandShapes(values)); + absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); @@ -1488,15 +1513,14 @@ XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, dimension = ShapeUtil::Rank(keys_shape) - 1; } instr.add_dimensions(dimension); - return values.has_value() - ? AddInstruction(std::move(instr), HloOpcode::kSort, - {keys, *values}) - : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + std::vector operands{keys}; + operands.insert(operands.end(), values.begin(), values.end()); + return AddInstruction(std::move(instr), HloOpcode::kSort, operands); }); } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); } @@ -1534,10 +1558,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, return TernaryOp(HloOpcode::kClamp, min, operand, max); } -XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::Map(absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { + absl::Span dimensions, + absl::Span static_operands) { return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); @@ -1578,7 +1602,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, + absl::Span parameters, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1590,7 +1614,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, if (parameters.size() != 2) { 
return InvalidArgument( "RNG distribution (%s) expects 2 parameters, but got %ld", - RandomDistribution_Name(distribution).c_str(), parameters.size()); + RandomDistribution_Name(distribution), parameters.size()); } break; default: @@ -1639,7 +1663,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1719,22 +1743,39 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, }); } -XlaOp XlaBuilder::Reduce( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { +XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return Reduce(absl::Span({operand}), + absl::Span({init_value}), computation, + dimensions_to_reduce); +} + +XlaOp XlaBuilder::Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferReduceShape( - {&operand_shape, &init_shape}, dimensions_to_reduce, - called_program_shape)); + std::vector all_operands; + all_operands.insert(all_operands.end(), operands.begin(), operands.end()); + all_operands.insert(all_operands.end(), init_values.begin(), + init_values.end()); + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& 
operand_shapes, + GetOperandShapes(all_operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1742,8 +1783,7 @@ XlaOp XlaBuilder::Reduce( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kReduce, - {operand, init_value}); + return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands); }); } @@ -1757,11 +1797,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, }); } -XlaOp XlaBuilder::ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { +XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1773,18 +1813,20 @@ XlaOp XlaBuilder::ReduceWindow( std::vector> padding_values = MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); + return ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values); }); } XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - 
tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1794,7 +1836,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + /*lhs_dilation=*/base_dilations, + /*rhs_dilation=*/window_dilations)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferReduceWindowShape(operand_shape, init_shape, @@ -1879,8 +1922,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, } XlaOp XlaBuilder::CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups) { + const XlaOp& operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); @@ -1895,7 +1937,7 @@ XlaOp XlaBuilder::CrossReplicaSum( XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups, + absl::Span replica_groups, const absl::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1974,12 +2016,34 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, }); } -XlaOp XlaBuilder::SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { +XlaOp XlaBuilder::CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return ReportErrorOrReturn([&]() 
-> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCollectivePermuteShape(operand_shape)); + + for (const auto& pair : source_target_pairs) { + auto* proto_pair = instr.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + + return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute, + {operand}); + }); +} + +XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( @@ -1992,11 +2056,10 @@ XlaOp XlaBuilder::SelectAndScatter( XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -2140,13 +2203,13 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "SendToHost shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + 
ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(operand_shape).c_str()); + ShapeUtil::HumanString(operand_shape)); } if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { @@ -2185,7 +2248,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, if (!ShapeUtil::IsArray(shape)) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { @@ -2240,7 +2303,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( "of being evaluated at XLA compile time.\n\n" "Please file a usability bug with the framework being used (e.g. " "TensorFlow).", - op_string.c_str()); + op_string); } TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, @@ -2254,22 +2317,24 @@ StatusOr XlaBuilder::BuildConstantSubGraph( *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is - // also a valid denpendency order). The related ops will be added to the + // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set related_ops; - tensorflow::gtl::FlatSet related_calls; // Related computations. + absl::flat_hash_set related_calls; // Related computations. 
std::queue worklist; worklist.push(root->id()); related_ops.insert(root->id()); while (!worklist.empty()) { - int64 node = worklist.front(); + int64 handle = worklist.front(); worklist.pop(); - for (int64 id : instructions_[node].operand_ids()) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(handle)); + for (int64 id : instr_proto->operand_ids()) { if (related_ops.insert(id).second) { worklist.push(id); } } - for (int64 called_id : instructions_[node].called_computation_ids()) { + for (int64 called_id : instr_proto->called_computation_ids()) { related_calls.insert(called_id); } } @@ -2277,7 +2342,9 @@ StatusOr XlaBuilder::BuildConstantSubGraph( // Add related ops to the computation. for (int64 id : related_ops) { auto* instr = entry.add_instructions(); - *instr = instructions_[id]; + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, + LookUpInstructionByHandle(id)); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. 
const string& new_name = StrCat(instr->name(), ".", entry.id(), ".", instr->id()); @@ -2290,7 +2357,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); - *module->mutable_program_shape() = *program_shape; + *module->mutable_host_program_shape() = *program_shape; for (auto& e : embedded_) { if (related_calls.find(e.second.id()) != related_calls.end()) { *module->add_computations() = e.second; @@ -2348,8 +2415,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the input are not unique: (%d, %d, %d, " + "%d)", dnum.input_batch_dimension(), dnum.input_feature_dimension(), dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); } @@ -2359,8 +2426,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.kernel_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the weight are not unique: (%d, %d, %d, " + "%d)", dnum.kernel_output_feature_dimension(), dnum.kernel_input_feature_dimension(), dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); @@ -2371,34 +2438,32 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.output_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the output are not unique: (%d, %d, %d, " + "%d)", dnum.output_batch_dimension(), dnum.output_feature_dimension(), dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); } return Status::OK(); } -StatusOr 
XlaBuilder::AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { +StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); - const int64 handle = instructions_.size(); + const int64 handle = GetUniqueId(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode())); + instr.set_name(instr.opcode()); } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { - return InvalidArgument("invalid XlaOp with handle %lld", - operand.handle()); + return InvalidArgument("invalid XlaOp with handle %d", operand.handle()); } if (operand.builder_ != this) { return InvalidArgument("Do not add XlaOp from builder %s to builder %s", - operand.builder_->name().c_str(), - this->name().c_str()); + operand.builder_->name(), this->name()); } instr.add_operand_ids(operand.handle()); } @@ -2408,7 +2473,8 @@ StatusOr XlaBuilder::AddInstruction( *instr.mutable_sharding() = *sharding_; } - instructions_.push_back(instr); + handle_to_index_[handle] = instructions_.size(); + instructions_.push_back(std::move(instr)); XlaOp op(handle, this); return op; @@ -2428,20 +2494,26 @@ StatusOr XlaBuilder::LookUpInstruction( if (op.builder_ == nullptr) { return InvalidArgument( - "invalid XlaOp with handle %lld; the builder of this op is freed", + "invalid XlaOp with handle %d; the builder of this op is freed", op.handle()); } if (op.builder_ != this) { return InvalidArgument( - "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "XlaOp with handle %d is built by builder '%s', but is trying to use " "it in builder '%s'", - op.handle(), op.builder_->name().c_str(), this->name().c_str()); + op.handle(), op.builder_->name(), this->name()); } - if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %lld", 
op.handle()); + return LookUpInstructionByHandle(op.handle()); +} + +StatusOr XlaBuilder::LookUpInstructionByHandle( + int64 handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); } - return &instructions_[op.handle()]; + return &instructions_[it->second]; } // Enqueues a "retrieve parameter value" instruction for a parameter that was @@ -2457,14 +2529,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions) { return operand.builder()->BroadcastInDim(operand, shape, broadcast_dimensions); } @@ -2474,26 +2544,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, return operand.builder()->Pad(operand, padding_value, padding_config); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes) { return operand.builder()->Reshape(operand, dimensions, new_sizes); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes) { return operand.builder()->Reshape(operand, new_sizes); } -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Collapse(operand, dimensions); } -XlaOp Slice(const XlaOp& operand, - 
tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return operand.builder()->Slice(operand, start_indices, limit_indices, strides); } @@ -2505,7 +2571,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, } XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } @@ -2514,8 +2580,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); } @@ -2528,7 +2593,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { return pred.builder()->Select(pred, on_true, on_false); } -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { +XlaOp Tuple(XlaBuilder* builder, absl::Span elements) { return builder->Tuple(elements); } @@ -2537,104 +2602,98 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return 
lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); } XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); } XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->Dot(lhs, rhs, precision_config_proto); + const PrecisionConfig* precision_config) { + return lhs.builder()->Dot(lhs, rhs, precision_config); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, - precision_config_proto); + precision_config); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + absl::Span window_strides, Padding padding, + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, 
window_strides, - padding, feature_group_count, - precision_config_proto); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralPadding( + lhs, rhs, window_strides, padding, feature_group_count, precision_config); } XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { +XlaOp ConvGeneralDilated(const 
XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config_proto); + dimension_numbers, feature_group_count, precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { + absl::Span fft_length) { return operand.builder()->Fft(operand, fft_type, fft_length); } @@ -2648,99 +2707,116 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, } XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return builder->Call(computation, operands); } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { - return builder->CustomCall(call_target_name, operands, shape); + absl::Span operands, const Shape& shape, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt); +} + +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return real.builder()->Complex(real, imag, broadcast_dimensions); } XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - 
tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); } XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->And(lhs, rhs, broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); } XlaOp Not(const XlaOp& operand) { return 
operand.builder()->Not(operand); } XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return operand.builder()->Reduce(operand, init_value, computation, dimensions_to_reduce); } +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. 
+XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return builder->Reduce(operands, init_values, computation, + dimensions_to_reduce); +} + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return operand.builder()->ReduceAll(operand, init_value, computation); @@ -2748,9 +2824,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { return operand.builder()->ReduceWindow(operand, init_value, computation, window_dimensions, window_strides, padding); @@ -2759,22 +2834,23 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, - padding); + base_dilations, window_dilations, padding); } -XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups) { +XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); } XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups, + absl::Span replica_groups, const 
absl::optional& channel_id) { return operand.builder()->CrossReplicaSum(operand, computation, replica_groups, channel_id); @@ -2787,11 +2863,17 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, split_count, replica_groups); } +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return operand.builder()->CollectivePermute(operand, source_target_pairs); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { return operand.builder()->SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2799,11 +2881,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return operand.builder()->SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2812,7 +2893,7 @@ XlaOp SelectAndScatterWithGeneralPadding( XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return 
y.builder()->Atan2(y, x, broadcast_dimensions); } @@ -2845,7 +2926,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); } @@ -2863,27 +2944,25 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { +XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); } -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { +XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension) { - return keys.builder()->Sort(keys, std::move(values), dimension); +XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { + return keys.builder()->Sort(keys, values, dimension); } XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } -XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands) { return builder->Map(operands, computation, dimensions, static_operands); } @@ -2917,7 +2996,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& 
dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return input.builder()->Gather(input, start_indices, dimension_numbers, slice_sizes); } @@ -2973,7 +3052,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens) { +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens) { return builder->AfterAll(tokens); } @@ -3000,11 +3079,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } -XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { - HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); - return builder->ReportErrorOrReturn( - builder->AddInstruction(std::move(instr), HloOpcode::kIota)); +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + return builder->Iota(type, size); +} + +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { + return builder->Iota(shape, iota_dimension); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index baa2ae51847e8da1360c607d361fba0463c320ad..5747661c34b411bbf22575f9c1d9fe09aa32911f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,7 +21,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -33,8 +36,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -294,7 +295,7 @@ class XlaBuilder { template XlaOp ConstantR0(NativeT value); template - XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(absl::Span values); XlaOp ConstantR1(const tensorflow::core::Bitmap& values); template XlaOp ConstantR2( @@ -336,7 +337,7 @@ class XlaBuilder { // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. // @@ -355,9 +356,8 @@ class XlaBuilder { // will generate output // [1 , 1] // [2 , 2] - XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -370,15 +370,13 @@ class XlaBuilder { // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. 
Conceptually, this is a limited form of "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -398,8 +396,7 @@ class XlaBuilder { // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. - XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -412,10 +409,9 @@ class XlaBuilder { // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice - XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -436,7 +432,7 @@ class XlaBuilder { // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -459,8 +455,7 @@ class XlaBuilder { // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. 
- XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + XlaOp ConcatInDim(absl::Span operands, int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -471,96 +466,93 @@ class XlaBuilder { XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. - XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + XlaOp Tuple(absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. 
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. 
XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. - XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. 
- XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -582,15 +574,13 @@ class XlaBuilder { // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - XlaOp CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); + XlaOp CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const string& opaque, + absl::optional> operand_shapes_with_layout); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. 
The shapes of the operands have to match unless one @@ -599,65 +589,70 @@ class XlaBuilder { // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. 
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + + // Reduces several arrays simultaneously among the provided dimensions, given + // "computation" as a reduction operator. + XlaOp Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. 
@@ -667,25 +662,25 @@ class XlaBuilder { // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. - XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups = {}); + XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -704,10 +699,10 @@ class XlaBuilder { // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. // - // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + // TODO(b/117564385): Rename this to AllReduce when it's ready to use. 
XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups = {}, + absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. @@ -715,11 +710,17 @@ class XlaBuilder { int64 concat_dimension, int64 split_count, const std::vector& replica_groups); + // Enqueues an operation that do an CollectivePermute of the operand cross + // cores. + XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); @@ -728,18 +729,17 @@ class XlaBuilder { // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. 
XlaOp Exp(const XlaOp& operand); @@ -786,7 +786,7 @@ class XlaBuilder { // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -794,6 +794,12 @@ class XlaBuilder { // entry was NaN. XlaOp IsFinite(const XlaOp& operand); + // Enqueues an iota operation onto the computation. + XlaOp Iota(const Shape& shape, int64 iota_dimension); + + // Enqueues a rank-1 iota operation onto the computation. + XlaOp Iota(PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, @@ -810,14 +816,12 @@ class XlaBuilder { XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. - XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); + XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). - XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -830,22 +834,21 @@ class XlaBuilder { // the last dimension is chosen by default. // // If both keys and values are provided: - // * The keys and the values must tensors with the same dimensions. The + // * The keys and all values must be tensors with the same dimensions. The // element types of the tensors may be different. 
// * The result is a tuple that consists of a sorted tensor of keys (along the - // provided dimension, as above) as the first element, and a tensor with their - // corresponding values as the second element. - XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, + // provided dimension, as above) as the first element, and tensors with their + // corresponding values as the other elements. + XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. - XlaOp Map(tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); + XlaOp Map(absl::Span operands, const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -872,7 +875,7 @@ class XlaBuilder { // Enqueues a Gather node onto the computation. XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -900,7 +903,7 @@ class XlaBuilder { // Enqueues an AfterAll operation with no operands producing a token-shaped // value. - XlaOp AfterAll(tensorflow::gtl::ArraySlice tokens); + XlaOp AfterAll(absl::Span tokens); // Enqueues a Recv node onto the computation. 
The data comes from a Send // instruction that shares the same channel handle and its shape must @@ -947,14 +950,15 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands = {}); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands = {}); void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); StatusOr LookUpInstruction(const XlaOp& op) const; + StatusOr LookUpInstructionByHandle( + int64 handle) const; // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -963,19 +967,17 @@ class XlaBuilder { // broadcast_dimensions specifies which dimensions to use for broadcasting // when the operation is between tensors of different ranks. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Internal helper method that does the building for an arbitrary ternary op. XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs); XlaOp RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); + absl::Span parameters, const Shape& shape); - StatusOr InDimBroadcast( - const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + StatusOr InDimBroadcast(const Shape& shape, const XlaOp& operand, + absl::Span broadcast_dimensions); // Internal helper method that creates a sequence of instructions that // performs an explicit broadcast of the operand to the target shape. @@ -991,7 +993,7 @@ class XlaBuilder { // Returns shapes for the operands. 
StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const; + absl::Span operands) const; // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful @@ -1008,12 +1010,11 @@ class XlaBuilder { // Helper function for creating a Window proto from user-supplied data. // Returns error if the user-supplied data was invalid. - StatusOr MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const; + StatusOr MakeWindow(absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const; string name_; // Name to use for the built computation. @@ -1027,13 +1028,17 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + absl::flat_hash_map handle_to_index_; + // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of // that XlaComputation. std::map embedded_; // The unique parameter numbers. - tensorflow::gtl::FlatSet parameter_numbers_; + absl::flat_hash_set parameter_numbers_; // The metadata to attach to each op. 
This is structured as a "modal"-like // operation, in order to simplify client code (and not sprinkle this metadata @@ -1057,7 +1062,7 @@ class XlaBuilder { friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template friend XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); + absl::Span values); friend XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template @@ -1097,185 +1102,187 @@ class XlaBuilder { friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); friend XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); friend XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& 
operand, const XlaOp& update, const XlaOp& start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - int64 dimension); + absl::Span operands, int64 dimension); friend void Trace(const string& tag, const XlaOp& operand); friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - friend XlaOp Tuple(XlaBuilder* builder, - tensorflow::gtl::ArraySlice elements); + friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_number, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + absl::Span window_strides, 
Padding padding, + int64 feature_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); - friend XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); + friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType 
fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); + absl::Span operands, const Shape& shape, + const string& opaque); + friend XlaOp CustomCallWithLayout( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, const string& opaque); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span 
broadcast_dimensions); friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Not(const XlaOp& operand); - friend XlaOp ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - friend XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - friend XlaOp ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); friend XlaOp 
ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - friend XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups); - friend XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups, - const absl::optional& channel_id); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + const XlaComputation& computation, + absl::Span replica_groups, + const absl::optional& channel_id); friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); - friend XlaOp SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + friend XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + friend XlaOp SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); friend XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& 
scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); friend XlaOp Abs(const XlaOp& operand); friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Exp(const XlaOp& operand); friend XlaOp Expm1(const XlaOp& operand); friend XlaOp Floor(const XlaOp& operand); @@ -1291,27 +1298,26 @@ class XlaBuilder { friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); - // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota - // in xla/client/lib/numeric.h with this (renamed to xla::Iota). - friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, + int64 iota_dimension); + friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp Neg(const XlaOp& operand); friend XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); - friend XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); - friend XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension); + absl::Span permutation); + friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); + friend XlaOp Sort(const XlaOp& keys, absl::Span values, + int64 dimension); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - friend XlaOp Map(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, + friend XlaOp Map(XlaBuilder* 
builder, absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + absl::Span dimensions, + absl::Span static_operands); friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); @@ -1325,7 +1331,7 @@ class XlaBuilder { const int mantissa_bits); friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, @@ -1359,8 +1365,7 @@ class XlaBuilder { const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); - friend XlaOp AfterAll(XlaBuilder* builder, - tensorflow::gtl::ArraySlice tokens); + friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1424,8 +1429,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); template XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template XlaOp ConstantR2(XlaBuilder* builder, @@ -1474,8 +1478,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); // The new dimensions index into copies of the operand, i.e. // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. 
// @@ -1494,9 +1497,8 @@ XlaOp Broadcast(const XlaOp& operand, // will generate output // [1 , 1] // [2 , 2] -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -1509,15 +1511,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -1537,8 +1537,7 @@ XlaOp Reshape(const XlaOp& operand, // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -1551,10 +1550,9 @@ XlaOp Collapse(const XlaOp& operand, // Note that "limit" means up-to-but-not-including; i.e. 
[start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -1575,7 +1573,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -1598,8 +1596,8 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, int64 dimension); +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, + int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -1610,94 +1608,91 @@ void Trace(const string& tag, const XlaOp& operand); XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); +XlaOp Tuple(XlaBuilder* builder, absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. 
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. 
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. 
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -1729,15 +1724,30 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, // Enqueues a call instruction onto the computation. XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); - -// Enqueues a custom call instruction onto the computation. 
-// During code generation, a call instruction is emitted which targets a -// symbol with the name |call_target_name|. The |operands| are passed to the -// call instruction. |shape| is the resultant shape. + absl::Span operands); + +// Enqueues a custom call instruction onto the computation. A custom call +// invokes code external to XLA. The |operands| are passed to the external code, +// and the external code is expected to produce a result of the given +// |shape|. The exact mechanism is backend-specific. For example, in the CPU +// backend, a call instruction is emitted which targets a symbol with the name +// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, +// but |call_target_name| should be short as it may be used in labels. |opaque| +// can encode arbitrarily large amounts of information. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); + absl::Span operands, const Shape& shape, + const string& opaque = ""); + +// Overload which constructs a custom call with fixed layouts. The operands will +// have the layouts specified by |operand_shapes_with_layout| when provided to +// external code, and the external code is expected to produce a result with the +// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| +// and |operand_shapes_with_layout| must have layouts. +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const string& opaque = ""); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1746,65 +1756,70 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, // Enqueues a complex compose instruction onto the computation. 
XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. 
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. 
@@ -1814,25 +1829,25 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. -XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups = {}); +XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -1850,10 +1865,10 @@ XlaOp CrossReplicaSum( // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. // -// TODO(b/79737069): Rename this to AllReduce when it's ready to use. +// TODO(b/117564385): Rename this to AllReduce when it's ready to use. 
XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups = {}, + absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. @@ -1861,30 +1876,41 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups = {}); +// Enqueues an collective operation that sends and receives data cross replicas. +// +// - `source_target_pair`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consists of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); // As SelectAndScatter(), but the padding is given in the format // returned by MakePadding(). 
XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -1931,7 +1957,7 @@ XlaOp Imag(const XlaOp& operand); // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -1939,6 +1965,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, // entry was NaN. XlaOp IsFinite(const XlaOp& operand); +// Enqueues an iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. 
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); @@ -1953,13 +1985,12 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); +XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); +XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -1972,22 +2003,21 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); // the last dimension is chosen by default. // // If both keys and values are provided: -// * The keys and the values must tensors with the same dimensions. The +// * The keys and all values must be tensors with the same dimensions. The // element types of the tensors may be different. // * The result is a tuple that consists of a sorted tensor of keys (along the -// provided dimension, as above) as the first element, and a tensor with their -// corresponding values as the second element. -XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, +// provided dimension, as above) as the first element, and tensors with their +// corresponding values as the other elements. +XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. 
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -2014,7 +2044,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, // Enqueues a Gather node onto the computation. XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -2072,7 +2102,7 @@ XlaOp CreateToken(XlaBuilder* builder); // Enqueues an AfterAll instruction which produces a token-shaped value and // takes a variadic number of token-shaped operands. The number of operands must // be greater than zero. Used for joining tokens. -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens); +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); // Normalizes operand across spatial and batch dimensions for each feature. 
// @@ -2116,12 +2146,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*LiteralUtil::CreateR0(value)); + return ConstantLiteral(LiteralUtil::CreateR0(value)); } template -XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); +XlaOp XlaBuilder::ConstantR1(absl::Span values) { + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template @@ -2133,44 +2163,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*LiteralUtil::CreateR2(values)); + return ConstantLiteral(LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); + return ConstantLiteral(LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); + return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { 
return ConstantLiteral( - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -2193,13 +2223,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); } template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template @@ -2212,13 +2241,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); } template @@ -2226,14 +2255,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateFromArray(values)); + LiteralUtil::CreateFromArray(values)); } template @@ -2241,15 +2269,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + 
builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateR2FromArray2D(values)); + LiteralUtil::CreateR2FromArray2D(values)); } template @@ -2258,7 +2285,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e..7c37ed00cd3dcc214fb0b36c0161d3c39a5bf8c8 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -320,6 +320,15 @@ TEST_F(XlaBuilderTest, AllToAll) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); } +TEST_F(XlaBuilderTest, CollectivePermute) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 22c9e83bb2ae9e3e205bdd480b64c703e31c6ffd..c9870b65b91c1ebd7d44143faf215a2d5c2a2fc5 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -24,8 +24,8 @@ limitations under the License. 
namespace xla { StatusOr XlaComputation::GetProgramShape() const { - TF_RET_CHECK(proto_.has_program_shape()); - return proto_.program_shape(); + TF_RET_CHECK(proto_.has_host_program_shape()); + return proto_.host_program_shape(); } StatusOr> XlaComputation::Snapshot() const { diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index a472747bd174e3bbd352f07f2ab092e678b81073..0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } +ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream( + stream_executor::Stream* stream) { + host_to_device_stream_ = stream; + return *this; +} + +stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const { + return host_to_device_stream_; +} + ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 416131be006e6ecddb47651f8b684c1d91df4892..ba3217f31b55bd1428f67da6154a46c8bc304053 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -65,6 +65,13 @@ class ExecutableRunOptions { ExecutableRunOptions& set_stream(stream_executor::Stream* stream); stream_executor::Stream* stream() const; + // If set, this is the stream to perform any pre-computation transfers on. + // The platform of the stream must match the platform the executable was + // built for. A value of nullptr indicates the option has not been set. 
+ ExecutableRunOptions& set_host_to_device_stream( + stream_executor::Stream* stream); + stream_executor::Stream* host_to_device_stream() const; + // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. ExecutableRunOptions& set_intra_op_thread_pool( @@ -90,6 +97,7 @@ class ExecutableRunOptions { const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; + stream_executor::Stream* host_to_device_stream_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 693dcb3a3eef37f92533f1add850395e51d4b910..3fadabcf5207097aa875d654320b930b1ed94ad3 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -27,7 +27,7 @@ limitations under the License. namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index) { + const Shape& shape, absl::Span multi_index) { DCHECK_EQ(shape.dimensions_size(), multi_index.size()); // Padding and nested layouts not supported yet. 
DCHECK_EQ(0, shape.layout().padded_dimensions_size()); @@ -118,8 +118,8 @@ namespace xla { return multi_index; } -/* static */ bool IndexUtil::BumpIndices( - const Shape& shape, tensorflow::gtl::MutableArraySlice indices) { +/* static */ bool IndexUtil::BumpIndices(const Shape& shape, + absl::Span indices) { for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) { int64 limit = shape.dimensions(dimno); if (indices[dimno] + 1 < limit) { @@ -149,8 +149,8 @@ namespace xla { return stride; } -/* static */ bool IndexUtil::IndexInBounds( - const Shape& shape, tensorflow::gtl::ArraySlice index) { +/* static */ bool IndexUtil::IndexInBounds(const Shape& shape, + absl::Span index) { int64 rank = ShapeUtil::Rank(shape); if (rank != index.size()) { return false; @@ -163,9 +163,8 @@ namespace xla { return true; } -/* static */ int IndexUtil::CompareIndices( - tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs) { +/* static */ int IndexUtil::CompareIndices(absl::Span lhs, + absl::Span rhs) { int64 rank = lhs.size(); CHECK_EQ(rhs.size(), rank); for (int64 dim = 0; dim < rank; ++dim) { diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 142006f2626e83d3254f2de65fc28fd5d6694e53..2979cf87dde92893ce2151cb09b46c8db8473b31 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -35,7 +35,7 @@ class IndexUtil { // on the shape and its layout. The first index in the multi_index is // dimension 0. 
static int64 MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index); + const Shape& shape, absl::Span multi_index); // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional @@ -58,8 +58,7 @@ class IndexUtil { // // Returns true iff the indices were successfully bumped; false if we've hit // the limit where it can no longer be bumped in-bounds. - static bool BumpIndices(const Shape& shape, - tensorflow::gtl::MutableArraySlice indices); + static bool BumpIndices(const Shape& shape, absl::Span indices); // Calculates the stride size (in number of elements, not byte size) of a // given logical shape dimension (from 0 to rank-1). If available, padded @@ -71,15 +70,14 @@ class IndexUtil { // Returns true iff the given multi-index is contained in the bounds for the // shape. - static bool IndexInBounds(const Shape& shape, - tensorflow::gtl::ArraySlice index); + static bool IndexInBounds(const Shape& shape, absl::Span index); // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is // returned. 
- static int CompareIndices(tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs); + static int CompareIndices(absl::Span lhs, + absl::Span rhs); private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 7c4efdee484d9530a69b31cbe3a0d69a8a3cffa7..93522d2ca87a7eba8d3c7533785c54e63ce507b0 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -142,13 +142,13 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { TEST(IndexUtilTest, BumpIndices2x2) { auto shape = ShapeUtil::MakeShape(S32, {2, 2}); std::vector indices = {0, 0}; - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(0, 1)); - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(1, 0)); - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(1, 1)); - EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_FALSE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); } } // namespace diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 61c26434b16513f59ba3aebb16f4706c5287e940..19667b7ed9d47896afd9a82a41de7997538b089b 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer( } // namespace /* static */ Layout LayoutUtil::MakeLayout( - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span minor_to_major) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { @@ -65,8 +65,14 @@ void SetDefaultLayoutToContainer( return layout; } 
+/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) { + std::vector layout(rank); + std::iota(layout.rbegin(), layout.rend(), static_cast(0)); + return MakeLayout(layout); +} + /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor) { + absl::Span major_to_minor) { Layout layout; layout.set_format(DENSE); for (int i = major_to_minor.size() - 1; i >= 0; i--) { @@ -169,7 +175,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return ValidateLayoutForShape(shape.layout(), shape); } else { @@ -177,7 +183,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (shape.has_layout()) { return InvalidArgument( "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -194,25 +200,25 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.padded_dimensions_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } - if (layout.format() == INVALID_FORMAT) { + if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { return InvalidArgument( "Layout does not have a valid format: layout {%s}, shape {%s}", - layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); + layout.ShortDebugString(), shape.ShortDebugString()); } if (layout.format() == DENSE) { if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", + "but shape is rank %d: {%s}; shape: %s", 
layout.minor_to_major_size(), ShapeUtil::Rank(shape), - absl::StrJoin(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + absl::StrJoin(layout.minor_to_major(), ", "), + shape.ShortDebugString()); } std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); @@ -221,12 +227,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + HumanString(layout)); } if (dimensions_in_layout[dim]) { return InvalidArgument( "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); + HumanString(layout)); } dimensions_in_layout[dim] = true; } @@ -234,14 +240,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.padded_dimensions_size() > 0) { if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", + "layout has %d padded dimensions, but shape is rank %d", layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); } for (int i = 0; i < layout.padded_dimensions_size(); ++i) { if (layout.padded_dimensions(i) < shape.dimensions(i)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", + "for dimension %d, dimension padding (%d) is smaller than " + "the dimension size (%d) of the shape", i, layout.padded_dimensions(i), shape.dimensions(i)); } } @@ -307,7 +313,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( +/* static */ absl::Span LayoutUtil::PaddedDimensions( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); @@ -363,13 +369,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } -/* 
static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().minor_to_major()); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Layout& layout) { CHECK(layout.format() == DENSE); return AsInt64Slice(layout.minor_to_major()); @@ -472,7 +478,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } /* static */ bool LayoutUtil::AreDimensionsConsecutive( - const Layout& layout, tensorflow::gtl::ArraySlice dims) { + const Layout& layout, absl::Span dims) { CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 739bbe73675c7fb855627006028eafdf703d6540..af032b1cae4c5645d6c7da55b779cd0a7336592e 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,10 +20,10 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -34,11 +34,15 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) - static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major); // Similar to MakeLayout, but take indices in reverse order. static Layout MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor); + absl::Span major_to_minor); + + // Returns a layout with descending ((i.e. 
{n, n-1, ..., 0}) minor-to-major + // dimensions. + static Layout MakeDescendingLayout(int64 rank); // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) @@ -104,8 +108,7 @@ class LayoutUtil { // Returns the padded_dimensions array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice PaddedDimensions( - const Shape& shape); + static absl::Span PaddedDimensions(const Shape& shape); // Returns the given index of the padded_dimensions array for the given Shape. // Requires that the shape is an array and has a dense layout. @@ -138,8 +141,8 @@ class LayoutUtil { // Returns the minor_to_major array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); - static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + static absl::Span MinorToMajor(const Shape& shape); + static absl::Span MinorToMajor(const Layout& layout); // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. @@ -196,7 +199,7 @@ class LayoutUtil { // Returns whether the given dimensions are consecutive in the given layout, // not necessarily in the order given. static bool AreDimensionsConsecutive(const Layout& layout, - tensorflow::gtl::ArraySlice dims); + absl::Span dims); // Compute a hash for `layout`. 
static size_t Hash(const Layout& layout); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index e4c825450dcd45a8fbeaacbb2ad145f94307176f..f25dae6ff411133c74502039f441060f1329ffd4 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -27,15 +27,15 @@ namespace { class LayoutUtilTest : public ::testing::Test { protected: Shape MakeShapeWithLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span dimensions, + absl::Span minor_to_major) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } Shape MakeShapeWithSparseLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, + absl::Span dimensions, int64 max_sparse_elements) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 989035896b17609b6055f7dd5df839fc61d5f447..3e79129aafd234e5eab05d205f2017b54057795e 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -39,6 +40,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", ], ) @@ -75,5 +77,6 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc 
b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 0d3136b0cc6a3a695eacb98c16200e46a144c571..3ed3afcfcede20fbf5c7d4f004378817febeb4c7 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // regression. flags->set_xla_cpu_enable_fast_math(true); flags->set_xla_gpu_enable_fast_math(true); + + flags->set_xla_force_host_platform_device_count(1); } // Allocates flag_values and flag_objects; this function must not be called more @@ -323,6 +325,17 @@ void AllocateFlags() { flag_values->xla_gpu_crash_on_verification_failures(), "Crashes the program on extra verification failures, e.g. cuDNN " "cross checking failures"), + tensorflow::Flag( + "xla_force_host_platform_device_count", + int32_setter_for( + &DebugOptions::set_xla_force_host_platform_device_count), + flag_values->xla_force_host_platform_device_count(), + "Force the host platform to pretend that there are these many " + "host \"devices\". All of these host devices are backed by the same" + "threadpool. Setting this to anything other than 1 can increase " + "overhead from context switching but we let the user override this " + "behavior to help run tests on the host that run models in parallel " + "across multiple devices."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index acda43839598a660a7396922c07b0971ede0b247..ee7eb019c07cf898e48886955b18710146644cac 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -21,7 +21,6 @@ limitations under the License. 
#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace legacy_flags { diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index 7b6ae311c1099dccb8dceb2f49743c1b185cd5ab..138c0c852e2bb0527d171f25b4d96cedc5671516 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" @@ -106,8 +106,8 @@ TEST(ParseFlagsFromEnv, File) { if (tmp_dir == nullptr) { tmp_dir = kTempDir; } - string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d", - tmp_dir, getpid()); + string tmp_file = + absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid()); FILE* fp = fopen(tmp_file.c_str(), "w"); CHECK_NE(fp, nullptr) << "can't write to " << tmp_file; for (int i = 0; kTestFlagString[i] != '\0'; i++) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 0c0b619d507204df0abbfb8ef7f3d142bd3e9290..656ce720a13d5c9622e9dc05ae04ddcac8cbeee5 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -41,7 +41,7 @@ namespace xla { namespace { using absl::StrCat; -using tensorflow::strings::Printf; +using absl::StrFormat; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; @@ -73,7 +73,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : dimensions(dimensions), base(dimensions.size(), 0), step(dimensions.size(), 1) { @@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) { return *this; } -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = absl::make_unique(shape); - literal->root_piece_->ForEachMutableSubpiece( +Literal LiteralBase::CreateFromShape(const Shape& shape) { + Literal literal(shape); + literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { memset(piece->untyped_data(), 0, piece->size_bytes()); @@ -197,14 +197,13 @@ SparseIndexArray* MutableLiteralBase::sparse_indices( template Status MutableLiteralBase::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { + const LiteralBase& src_literal, absl::Span src_base, + absl::Span dest_base, absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return 
IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; @@ -232,7 +231,7 @@ Status MutableLiteralBase::CopySliceFromInternal( MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + auto copy_proc = [&](absl::Span indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -257,10 +256,9 @@ Status MutableLiteralBase::CopySliceFromInternal( return Status::OK(); } -Status MutableLiteralBase::CopyElementFrom( - const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { +Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, + absl::Span src_index, + absl::Span dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -280,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom( return Status::OK(); } -/* static */ StatusOr> -MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { +/* static */ StatusOr MutableLiteralBase::CreateFromProto( + const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -289,9 +287,11 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = absl::make_unique(proto.shape()); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + Literal literal(proto.shape()); + + TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { const LiteralProto* proto_element = &proto; for (int64 i : index) { @@ -303,7 
+303,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", + "Expected %d tuple elements in LiteralProto, has %d", ShapeUtil::TupleElementCount(piece->subshape()), proto_element->tuple_literals_size()); } @@ -355,9 +355,9 @@ namespace { // Copies the elements in 'src' to 'dest'. The shape and layout of the data in // the array slices are indicated by dest_shape and src_shape respectively. template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { +void CopyElementsBetween(absl::Span dest, + absl::Span src, const Shape& dest_shape, + const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; @@ -366,7 +366,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; - } while (IndexUtil::BumpIndices(dest_shape, &index)); + } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index))); } } // namespace @@ -404,7 +404,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { default: return Unimplemented( "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); + PrimitiveType_Name(subshape().element_type())); } } return Status::OK(); @@ -420,8 +420,8 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { return InvalidArgument( "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - 
ShapeUtil::HumanString(src_subshape).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_subshape)); } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -458,8 +458,8 @@ Status Literal::MoveFrom(Literal&& src_literal, if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { return InvalidArgument( "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_literal.shape())); } src_literal.root_piece_->ForEachSubpiece( @@ -487,11 +487,10 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status MutableLiteralBase::CopySliceFrom( - const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -559,40 +558,38 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { } } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { +Literal LiteralBase::Relayout(const Layout& new_layout, + const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. 
Shape new_shape = shape(); Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = absl::make_unique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { +Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) << " not compatible with literal shape " << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); + Literal result = CreateFromShape(shape_with_layout); ShapeUtil::ForEachSubshape( - result->shape(), + result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); + TF_CHECK_OK(result.CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); } }); return result; } -StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const { +StatusOr LiteralBase::Broadcast( + const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); } @@ -602,20 +599,20 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = absl::make_unique(result_shape); + Literal result(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in // every iteration of ShapeUtil::ForEachIndex. 
std::vector scratch_source_index(shape().dimensions_size()); - char* dest_data = static_cast(result->untyped_data()); + char* dest_data = static_cast(result.untyped_data()); const char* source_data = static_cast(untyped_data()); const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + result_shape, [&](absl::Span output_index) { for (int64 i = 0; i < dimensions.size(); ++i) { scratch_source_index[i] = output_index[dimensions[i]]; } @@ -631,37 +628,36 @@ StatusOr> LiteralBase::Broadcast( return std::move(result); } -StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { +StatusOr LiteralBase::Reshape( + absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } - std::unique_ptr output; + Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { - output = CloneToUnique(); + output = Clone(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
- *output->mutable_shape_do_not_use() = + *output.mutable_shape_do_not_use() = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + int64 elements_after = ShapeUtil::ElementsIn(output.shape()); if (elements_before != elements_after) { return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); + ShapeUtil::HumanString(shape()), + ShapeUtil::HumanString(output.shape())); } return std::move(output); } -std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { +Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -691,33 +687,31 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = absl::make_unique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + Literal new_literal(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); return new_literal; } template -std::unique_ptr LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = absl::make_unique(result_shape); +Literal LiteralBase::SliceInternal( + const Shape& result_shape, absl::Span start_indices) const { + Literal result_literal(result_shape); DimensionVector 
new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + result_literal.EachCell( + [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); - result_literal->Set(indices, value); + result_literal.Set(indices, value); }); return result_literal; } -std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { +Literal LiteralBase::Slice(absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -733,16 +727,34 @@ std::unique_ptr LiteralBase::Slice( ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); switch (result_shape.element_type()) { - case F32: - return SliceInternal(result_shape, start_indices); + case PRED: + return SliceInternal(result_shape, start_indices); + case U8: + return SliceInternal(result_shape, start_indices); + case U16: + return SliceInternal(result_shape, start_indices); + case U32: + return SliceInternal(result_shape, start_indices); + case U64: + return SliceInternal(result_shape, start_indices); + case S8: + return SliceInternal(result_shape, start_indices); + case S16: + return SliceInternal(result_shape, start_indices); + case S32: + return SliceInternal(result_shape, start_indices); + case S64: + return SliceInternal(result_shape, start_indices); + case F16: + return SliceInternal(result_shape, start_indices); case BF16: return SliceInternal(result_shape, start_indices); + case F32: + return SliceInternal(result_shape, start_indices); + case F64: + return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); - case S32: - return 
SliceInternal(result_shape, start_indices); - case U32: - return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -755,13 +767,7 @@ Literal LiteralBase::Clone() const { return result; } -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = absl::make_unique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, +string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); @@ -858,7 +864,7 @@ string LiteralBase::GetSparseElementAsString( } StatusOr LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { + absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -874,9 +880,8 @@ StatusOr LiteralBase::GetIntegralAsS64( case U64: return Get(multi_index); default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } } @@ -901,8 +906,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status MutableLiteralBase::SetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index, int64 value) { +Status MutableLiteralBase::SetIntegralAsS64(absl::Span multi_index, + int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -924,14 +929,13 @@ Status MutableLiteralBase::SetIntegralAsS64( Set(multi_index, value); break; default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is 
not integral: %s", + PrimitiveType_Name(shape().element_type())); } return Status::OK(); } -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( +absl::Span LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1000,7 +1004,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() { auto values = data(); CHECK_LE(num_elements, values.size()); sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); + absl::Span(values.data(), num_elements)); } namespace { @@ -1066,8 +1070,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, CHECK(LayoutUtil::IsDenseArray(subshape)); - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { + auto element_to_string = [&](absl::Span indices) -> string { PrimitiveType element_type = subshape.element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. 
@@ -1116,9 +1119,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { pieces->push_back(" {"); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { @@ -1136,11 +1139,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { pieces->push_back(" {"); for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { @@ -1162,7 +1165,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {"); literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { + [&](absl::Span indices, const string& value) { pieces->push_back(" "); pieces->push_back(value); }); @@ -1185,7 +1188,7 @@ string LiteralBase::ToString(bool print_layout) const { } void LiteralBase::EachCellAsString( - const std::function indices, + const std::function indices, const string& 
value)>& per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -1194,19 +1197,19 @@ void LiteralBase::EachCellAsString( shape(), /*linear_index=*/0); do { per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); + } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); } namespace { template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { +Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, + const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( + Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); + auto dest_data = result_literal.template data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { @@ -1216,8 +1219,7 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1225,7 +1227,7 @@ std::unique_ptr ConvertBetweenNativeTypes( template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); @@ -1240,22 +1242,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { // identical sizes higher up. 
template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { +Literal ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique( + Literal result_literal( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); + absl::Span src_data = src_literal.data(); + absl::Span dest_data = result_literal.data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1264,8 +1264,7 @@ std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { +Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { return BitcastBetweenNativeTypes< @@ -1283,9 +1282,9 @@ std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, } template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, + PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ case (type): \ @@ -1312,18 +1311,17 @@ StatusOr> ConvertIfDestTypeMatches( default: break; } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - 
PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertSwitch(const LiteralBase& literal, + PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); + return literal.Clone(); } switch (literal.shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ @@ -1344,47 +1342,37 @@ StatusOr> ConvertSwitch( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("%s from type %s to type %s is not implemented.", + (bitcast ? 
"Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } } } // namespace -StatusOr> LiteralBase::Convert( +StatusOr LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> LiteralBase::BitcastConvert( +StatusOr LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { return InvalidArgument( "Cannot bitcast convert from %s to %s, bit widths are different: %d != " "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), + PrimitiveType_Name(shape().element_type()), + PrimitiveType_Name(primitive_dest_type), primitive_util::BitWidth(shape().element_type()), primitive_util::BitWidth(primitive_dest_type)); } return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { +StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } return Convert(dest_shape.element_type()); } std::vector elements; @@ -1393,15 +1381,13 @@ StatusOr> LiteralBase::ConvertToShape( TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); + elements.push_back(std::move(new_element)); } - auto converted = absl::make_unique(); - *converted = MutableLiteralBase::MoveIntoTuple(&elements); - return std::move(converted); + return 
MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); } /* static */ Literal MutableLiteralBase::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const Literal& element : elements) { element_shapes.push_back(element.shape()); @@ -1492,7 +1478,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { namespace { template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, +static bool AllElementsEqualValue(absl::Span data, NativeT value) { for (int64 i = 0; i < data.size(); ++i) { if (data[i] != value) { @@ -1691,7 +1677,62 @@ bool LiteralBase::IsAllFirst() const { }); } -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsR1Iota() const { + if (!ShapeUtil::IsArray(shape())) { + return false; + } + + if (ShapeUtil::Rank(shape()) != 1) { + return false; + } + + auto is_iota_at_idx = [&](const int64 idx) { + switch (shape().element_type()) { + case U8: + return Get({idx}) == idx; + case U16: + return Get({idx}) == idx; + case U32: + return Get({idx}) == idx; + case U64: + return Get({idx}) == idx; + case S8: + return Get({idx}) == idx; + case S16: + return Get({idx}) == idx; + case S32: + return Get({idx}) == idx; + case S64: + return Get({idx}) == idx; + case F32: + return Get({idx}) == idx; + case F64: + return Get({idx}) == idx; + case F16: + return Get({idx}) == static_cast(idx); + case BF16: + return Get({idx}) == static_cast(idx); + case C64: + return Get({idx}) == complex64(idx, 0.0f); + case PRED: + return Get({idx}) == idx; + // token, opaque, tuple, etc. are all not iota. 
+ default: + return false; + } + }; + + const int64 elements = ShapeUtil::ElementsIn(shape()); + for (int64 idx = 0; idx < elements; ++idx) { + if (!is_iota_at_idx(idx)) { + return false; + } + } + + return true; +} + +bool LiteralBase::IsZero(absl::Span indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1727,7 +1768,7 @@ namespace { template void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { + const absl::Span src) { *dest = RepeatedFieldT(src.begin(), src.end()); } @@ -1739,6 +1780,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); break; + case S8: + proto->set_s8s(static_cast(data().data()), + element_count()); + break; case U8: proto->set_u8s(static_cast(data().data()), element_count()); @@ -1805,7 +1850,7 @@ void* LiteralBase::Piece::untyped_data() { namespace { template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, +Status CopyFromRepeatedField(absl::Span dest, const RepeatedFieldT& src) { if (dest.size() != src.size()) { return InvalidArgument( @@ -1825,10 +1870,33 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + if (LayoutUtil::IsSparseArray(subshape())) { + // Compute the number of elements (indices) in the sparse shape and reserve + // the necessary space in sparse_indices. 
+ TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) + << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + << "Unexpected number of indices in proto (" + << proto.sparse_indices_size() << ") for shape of rank " + << ShapeUtil::Rank(subshape()); + const int64 index_count = + proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + sparse_indices()->Resize(index_count); + + // Copy the indices from the proto into the SparseIndexArray object. + TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(), + proto.sparse_indices())); + } + switch (subshape().element_type()) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); break; + case S8: { + auto s8_data = data(); + TF_RET_CHECK(proto.s8s().size() == s8_data.size()); + std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin()); + } break; case U8: { auto u8_data = data(); TF_RET_CHECK(proto.u8s().size() == u8_data.size()); @@ -1877,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { } } break; case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; + return InvalidArgument("Should not be called on tuple shapes: %s", + ShapeUtil::HumanString(subshape())); default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + return InvalidArgument("Is called on unsupported shape: %s", + ShapeUtil::HumanString(subshape())); } return Status::OK(); } @@ -2075,8 +2143,8 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) root_piece_.set_subshape(shape_.get()); } -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) +BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, + const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); 
CHECK(!ShapeUtil::IsNestedTuple(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index aad435ed5b288176ebada8d1bcf1cd0239e0de68..3cd3541fe1596600b4f0b43e3011e1f0322ac8fe 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -70,13 +70,12 @@ class LiteralBase { // Serialize to proto. LiteralProto ToProto() const; - // Returns an ArraySlice of the array for this literal for the given NativeT + // Returns a Span of the array for this literal for the given NativeT // (e.g., float). CHECKs if the subshape of the literal at the given // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type // to native type. template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + absl::Span data(const ShapeIndex& shape_index = {}) const; // Returns a const pointer to the sparse index array. Returns nullptr if the // literal is not a sparse array. @@ -100,12 +99,12 @@ class LiteralBase { // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, + NativeT Get(absl::Span multi_index, const ShapeIndex& shape_index) const; // Overloads of Get for array literals. 
CHECKs if the literal is not // array-shaped and dense. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + NativeT Get(absl::Span multi_index) const; // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -114,7 +113,7 @@ class LiteralBase { // As Get(), but determines the correct type and converts the value // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, + string GetAsString(absl::Span multi_index, const ShapeIndex& shape_index = {}) const; // As GetSparseElement(), but determines the correct type and converts the // value into text. @@ -122,14 +121,13 @@ class LiteralBase { const ShapeIndex& shape_index = {}) const; // As Get(), but determines the correct type and converts the value into // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; + StatusOr GetIntegralAsS64(absl::Span multi_index) const; // Returns the multi-index of the element in a sparse literal at the given // sparse element number. The sparse element number is the position with in // the sparse array's list of (index, value) pairs, and is checked against the // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( + absl::Span GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; // Returns the value of the element in a sparse literal at the given sparse @@ -150,12 +148,12 @@ class LiteralBase { // // This literal must have a dense layout. void EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const; template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; + void EachCell( + std::function indices, NativeT value)> + per_cell) const; // Returns whether every element in this literal is equal to value. 
// @@ -195,13 +193,20 @@ class LiteralBase { // Literal consists entirely of the first element of the literal. bool IsAllFirst() const; + // Literal consists entirely of an iota. + bool IsR1Iota() const; + // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + bool IsZero(absl::Span indices) const; // Returns the count of the elements in the array at the given shape index in // this literal. int64 element_count(const ShapeIndex& index = {}) const { + if (index.empty()) { + // Common case, avoid GetSubshape(). + return ShapeUtil::ElementsIn(shape()); + } return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); } @@ -216,31 +221,20 @@ class LiteralBase { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + StatusOr ConvertToShape(const Shape& dest_shape) const; // Converts this literal to another primitive type using a bitcast // conversion. The to and from primitive types must have the same bit // width. Returns an error if the conversion is not possible. This literal // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; + StatusOr BitcastConvert(PrimitiveType primitive_dest_type) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. 
- StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + StatusOr Convert(PrimitiveType primitive_dest_type) const; - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. + // Clones the underlying buffers into a new Literal. Literal Clone() const; - std::unique_ptr CloneToUnique() const; // TODO(b/67651157): The methods below which perform computation on Literals // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with @@ -258,25 +252,23 @@ class LiteralBase { // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; + Literal Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; // An overload of Relayout which changes the layout of the entire shape rather // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; + Literal Relayout(const Shape& shape_with_layout) const; // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; + StatusOr Reshape(absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; + StatusOr Broadcast(const Shape& result_shape, + absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. 
// The given `permutation` must be a permutation of the dimension numbers @@ -285,8 +277,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; + Literal Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -294,16 +285,15 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; + Literal Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. // This literal must be an array. template - std::unique_ptr Replicate(int64 times) const; + Literal Replicate(int64 times) const; // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive @@ -314,7 +304,7 @@ class LiteralBase { // initialization, then reinitialization. Conside if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); + static Literal CreateFromShape(const Shape& shape); protected: // A data structure representing a subshape at a particular ShapeIndex within @@ -325,9 +315,9 @@ class LiteralBase { // Returns the buffer holding the array data for this piece as an array // slice. This piece must be array-shaped. 
template - tensorflow::gtl::ArraySlice data() const; + absl::Span data() const; template - tensorflow::gtl::MutableArraySlice data(); + absl::Span data(); // Returns the buffer holding the array data for this piece as a void*. This // piece must be array-shaped. @@ -338,9 +328,9 @@ class LiteralBase { // is CHECKed against the dimension sizes of the array. This piece must be // array-shaped. template - NativeT Get(tensorflow::gtl::ArraySlice index) const; + NativeT Get(absl::Span index) const; template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); + void Set(absl::Span index, NativeT value); // Gets/sets the buffer holding the array data. char* buffer() const { return buffer_; } @@ -541,9 +531,8 @@ class LiteralBase { private: template - std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; + Literal SliceInternal(const Shape& result_shape, + absl::Span start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -551,13 +540,12 @@ class MutableLiteralBase : public LiteralBase { public: virtual ~MutableLiteralBase() = 0; - // Returns a MutableArraySlice view of the array for this literal for the + // Returns a Span view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the // given ShapeIndex is not array. See primitive_util.h for the mapping from // XLA type to native type. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + absl::Span data(const ShapeIndex& shape_index = {}); // Unhide const method from parent class. using LiteralBase::data; @@ -584,8 +572,7 @@ class MutableLiteralBase : public LiteralBase { // are populated. 
template void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); + absl::Span values, bool sort = true); // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. The subshape of this literal rooted @@ -606,39 +593,38 @@ class MutableLiteralBase : public LiteralBase { // corresponding base indices being 0. // This literal and 'src_literal' must be arrays. Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Copies one element from src_literal[src_index] to (*this)[dest_index]. Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); + absl::Span src_index, + absl::Span dest_index); // Sets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); + void Set(absl::Span multi_index, const ShapeIndex& shape_index, + NativeT value); // Overloads of Set for array literals. CHECKs if the literal is not // array-shaped and dense. template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + void Set(absl::Span multi_index, NativeT value); // Appends the given element to the literal. If the elements are not appended // in sorted order, then SortSparseElements should be called before calling // other methods. This literal must have a sparse layout. template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); + void AppendSparseElement(absl::Span multi_index, NativeT value, + const ShapeIndex& shape_index = {}); // Sorts the elements in a sparse array. 
void SortSparseElements(const ShapeIndex& shape_index = {}); // As Set(), but truncates `value` to the literal element type before storing. // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + Status SetIntegralAsS64(absl::Span multi_index, int64 value); // Populate this literal with the given values. Examples: // @@ -653,7 +639,7 @@ class MutableLiteralBase : public LiteralBase { // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 // array of S32. template - void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(absl::Span values); void PopulateR1(const tensorflow::core::Bitmap& values); template void PopulateR2(std::initializer_list> values); @@ -670,7 +656,7 @@ class MutableLiteralBase : public LiteralBase { // in this literal object. // // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // NativeT(absl::Span indexes) or compatible. // // This literal must have a dense layout. template @@ -690,12 +676,10 @@ class MutableLiteralBase : public LiteralBase { // moved into the tuple elements of a new tuple-shaped Literal which is // returned. Upon return, each of the Literals in 'elements' is set to a nil // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); + static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static StatusOr CreateFromProto(const LiteralProto& proto); protected: // Returns the piece at the given ShapeIndex. @@ -709,20 +693,20 @@ class MutableLiteralBase : public LiteralBase { // arguments one by one. 
template Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. struct StrideConfig { StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // The dimensions of the stride operation. Essentially every dimension // will be iterated from base[i] to base[i]+dimensions[i], in step[i] // steps. - tensorflow::gtl::ArraySlice dimensions; + absl::Span dimensions; DimensionVector base; DimensionVector step; int64 minor_dimension = 0; @@ -851,7 +835,7 @@ class BorrowingLiteral : public LiteralBase { // This constructor is only used for array shapes. BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape); // TODO(b/79707221): adding constructors for nested tuples as well. 
@@ -871,41 +855,40 @@ class BorrowingLiteral : public LiteralBase { }; template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) +absl::Span LiteralBase::Piece::data() const { + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) +absl::Span LiteralBase::Piece::data() { + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { +NativeT LiteralBase::Piece::Get(absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)]; } template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice 
multi_index, +void LiteralBase::Piece::Set(absl::Span multi_index, NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -913,39 +896,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, } template -tensorflow::gtl::ArraySlice LiteralBase::data( +absl::Span LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } template -tensorflow::gtl::MutableArraySlice MutableLiteralBase::data( - const ShapeIndex& shape_index) { +absl::Span MutableLiteralBase::data(const ShapeIndex& shape_index) { return piece(shape_index).data(); } template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, +inline NativeT LiteralBase::Get(absl::Span multi_index, const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { +inline NativeT LiteralBase::Get(absl::Span multi_index) const { return root_piece().Get(multi_index); } template -inline void MutableLiteralBase::Set( - tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + const ShapeIndex& shape_index, + NativeT value) { return piece(shape_index).Set(multi_index, value); } template -inline void MutableLiteralBase::Set( - tensorflow::gtl::ArraySlice multi_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + NativeT value) { return root_piece().Set(multi_index, value); } @@ -964,7 +945,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, template void MutableLiteralBase::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, + absl::Span multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); @@ -980,8 +961,7 @@ void 
MutableLiteralBase::AppendSparseElement( template void LiteralBase::EachCell( - std::function indices, - NativeT value)> + std::function indices, NativeT value)> per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -989,12 +969,11 @@ void LiteralBase::EachCell( std::vector indices(ShapeUtil::Rank(shape()), 0); do { per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); + } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); } template -inline void MutableLiteralBase::PopulateR1( - tensorflow::gtl::ArraySlice values) { +inline void MutableLiteralBase::PopulateR1(absl::Span values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -1039,8 +1018,9 @@ void MutableLiteralBase::PopulateFromArray(const Array& values) { for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); + values.Each([this](absl::Span indices, NativeT value) { + this->Set(indices, value); + }); } template @@ -1059,9 +1039,9 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { } template -void MutableLiteralBase::PopulateSparse( - SparseIndexArray indices, tensorflow::gtl::ArraySlice values, - bool sort) { +void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, + absl::Span values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1071,7 +1051,7 @@ void MutableLiteralBase::PopulateSparse( CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices + // Piece::data() returns a Span of size equal to the number of indices // in the 
SparseIndexArray. So there is no need to adjust the size of the data // here. It is enough to just copy the incoming values into the data buffer. std::copy(values.begin(), values.end(), root_data.begin()); @@ -1091,14 +1071,14 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator, TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); + absl::Span literal_data = data(); if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + auto init_function = [&](absl::Span indexes) { DimensionVector minor_scan_indexes(rank, 0); const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); @@ -1116,7 +1096,7 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator, ShapeUtil::ForEachIndex( this_shape, stride_config.base, stride_config.dimensions, stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { + [&init_function](absl::Span indexes) { init_function(indexes); return true; }); @@ -1148,27 +1128,26 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) { } template -std::unique_ptr LiteralBase::Replicate(int64 times) const { +Literal LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = absl::make_unique( - ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); + Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal.shape()); if (elements == 0) { return literal; } 
DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; + absl::Span input_indices = output_indices; input_indices.remove_prefix(1); bool done = false; while (!done) { const auto element = Get(input_indices); - literal->Set(output_indices, element); + literal.Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 67a69c240321779503bd3e1e20cfbaed842ee034..3d8725ed7051cafc97987f25a96004fa876dfdd3 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -20,15 +20,15 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/env.h" using absl::StrAppend; +using absl::StrAppendFormat; using absl::StrCat; -using tensorflow::strings::Appendf; -using tensorflow::strings::Printf; namespace xla { namespace literal_comparison { @@ -38,8 +38,8 @@ namespace { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. 
template -Status CompareFloatsBitwiseEqual( - FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice multi_index) { +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); auto lhs_double = static_cast(lhs); @@ -48,9 +48,9 @@ Status CompareFloatsBitwiseEqual( return InvalidArgument( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a at array index %s", - StrCat(absl::Hex(ulhs)).c_str(), lhs_double, lhs_double, - StrCat(absl::Hex(urhs)).c_str(), rhs_double, rhs_double, - LiteralUtil::MultiIndexAsString(multi_index).c_str()); + StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, + StrCat(absl::Hex(urhs)), rhs_double, rhs_double, + LiteralUtil::MultiIndexAsString(multi_index)); } return Status::OK(); } @@ -60,43 +60,41 @@ Status CompareFloatsBitwiseEqual( // default gunit implementation). template Status CompareEqual(NativeT lhs, NativeT rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { if (lhs == rhs) { return Status::OK(); } return InvalidArgument( "first mismatch at array index %s:\n expected value: %s\n actual " "value: %s", - LiteralUtil::MultiIndexAsString(multi_index).c_str(), StrCat(lhs).c_str(), - StrCat(rhs).c_str()); + LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. 
template <> Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual( - Eigen::half lhs, Eigen::half rhs, - tensorflow::gtl::ArraySlice multi_index) { +Status CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> Status CompareEqual(float lhs, float rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> Status CompareEqual(double lhs, double rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> Status CompareEqual(complex64 lhs, complex64 rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); if (!res.ok()) { return res; @@ -109,8 +107,7 @@ Status CompareEqual(complex64 lhs, complex64 rhs, // elements are equal. template Status Equal(LiteralSlice expected, LiteralSlice actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { + absl::Span multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); @@ -165,15 +162,26 @@ bool NanMismatch(half expected, half actual, bool relaxed_nans) { static_cast(actual), relaxed_nans); } +// Returns whether the given value is infinity. +template +bool IsInf(NativeT val) { + return std::isinf(val); +} + +template <> +bool IsInf(half val) { + return std::isinf(static_cast(val)); +} + // Converts the given floating-point value to a string. 
template string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); + return absl::StrFormat("%8.4g", static_cast(value)); } template <> string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } // Returns the absolute value of the given floating point value. This function @@ -228,13 +236,12 @@ class NearComparator { } string ToString(const Shape& shape) const { - return Printf( + return absl::StrFormat( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + FpValueToString(actual), FpValueToString(expected), LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), + linear_index)), rel_error, abs_error); } }; @@ -258,7 +265,7 @@ class NearComparator { TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); if (!ShapeUtil::IsArray(expected_.shape())) { return InvalidArgument("Expected array shape; got %s.", - ShapeUtil::HumanString(expected_.shape()).c_str()); + ShapeUtil::HumanString(expected_.shape())); } mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); @@ -271,7 +278,7 @@ class NearComparator { } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { miscompare_callback_(expected_, actual_, mismatches_); } - return InvalidArgument("%s", ErrorMessage().c_str()); + return InvalidArgument("%s", ErrorMessage()); } // Insert the given absolute value into the absolute value bucket vector. The @@ -296,8 +303,7 @@ class NearComparator { } // Insert the given error into the given error bucket vector. 
- void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { + void UpdateErrorBucket(float error, absl::Span error_buckets) { CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < error_buckets.size(); ++i) { if (error >= kErrorBucketBounds[i]) { @@ -308,12 +314,13 @@ class NearComparator { // Compares the two given elements from the expected and actual literals at // the given literal_index and keeps track of various mismatch statistics. - void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + template + void CompareValues(T expected, T actual, int64 linear_index) { const bool is_nan_mismatch = NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (actual == expected) { + if (CompareEqual(expected, actual, {linear_index}).ok()) { abs_error = 0; rel_error = 0; } else if (is_nan_mismatch) { @@ -324,6 +331,12 @@ class NearComparator { // weak ordering requirement of std containers. abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); + } else if (IsInf(expected) || IsInf(actual)) { + // If either the expected or actual value is infinity but not both, + // then both absolute and relative error are regarded as inifity. + CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); rel_error = abs_error / FpAbsoluteValue(expected); @@ -337,11 +350,11 @@ class NearComparator { // bound is exceeded and vice versa. 
if (is_abs_mismatch) { num_abs_mismatches_++; - UpdateErrorBucket(rel_error, &rel_error_buckets_); + UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_)); } if (is_rel_mismatch) { num_rel_mismatches_++; - UpdateErrorBucket(abs_error, &abs_error_buckets_); + UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_)); } UpdateAbsValueBucket(actual, is_mismatch); @@ -366,15 +379,36 @@ class NearComparator { mismatches_.data()[linear_index] = true; } + // For complex64 types, we compare real and imaginary parts individually. + void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. if (LayoutUtil::Equal(actual_.shape().layout(), expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); + absl::Span expected_data = expected_.data(); + absl::Span actual_data = actual_.data(); const int64 len = expected_data.size(); for (int64 i = 0; i < len; ++i) { CompareValues(expected_data[i], actual_data[i], i); @@ -410,23 +444,23 @@ class NearComparator { auto percent_string = [](float a, float b) { float pct = b == 0.0 ? 
0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); + return absl::StrFormat("%0.4f%%", pct); }; - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + StrAppendFormat( + &out, + "\nMismatch count %d (%s) in shape %s (%d elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, percent_string(num_mismatches_, element_count), + ShapeUtil::HumanString(actual_.shape()), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); if (num_nan_mismatches_ > 0) { StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); } - Appendf(&out, "Top relative error mismatches:\n"); + StrAppendFormat(&out, "Top relative error mismatches:\n"); for (auto it = top_rel_mismatches_.rbegin(); it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + StrAppend(&out, " ", it->ToString(actual_.shape()), "\n"); } if (!detailed_message_) { @@ -438,36 +472,37 @@ class NearComparator { for (int i = 0; i < abs_value_buckets_.size(); ++i) { const int64 bucket_size = abs_value_buckets_[i].first; const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); + string mismatch_str = + bucket_mismatches > 0 + ? 
absl::StrFormat(", mismatches %d", bucket_mismatches) + : ""; + StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count), + mismatch_str); } auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { + absl::Span buckets) { StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); + StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total)); CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); + StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total)); } }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count)); print_accum_buckets( "Relative error breakdown of elements exceeding abs error bound", num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count)); print_accum_buckets( "Absolute error breakdown of elements exceeding rel error bound", num_rel_mismatches_, abs_error_buckets_); @@ -539,40 
+574,41 @@ constexpr std::array NearComparator::kErrorBucketBounds; Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); std::vector multi_index(expected.shape().dimensions_size(), 0); + auto index = absl::MakeSpan(multi_index); Status result; switch (expected.shape().element_type()) { case PRED: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case U8: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case S32: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case S64: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case U32: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case U64: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case BF16: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case F16: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case F32: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case F64: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case C64: - result = Equal(expected, actual, &multi_index, 0); + result = Equal(expected, actual, index, 0); break; case TUPLE: { for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { @@ -612,9 +648,9 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, NearHelper(expected_element, actual_element, error, detailed_message, miscompare_callback, element_index); if (!element_result.ok()) { - element_result = 
InvalidArgument( - "Array at shape index %s, %s", element_index.ToString().c_str(), - element_result.error_message().c_str()); + element_result = InvalidArgument("Array at shape index %s, %s", + element_index.ToString(), + element_result.error_message()); if (return_status.ok()) { return_status = element_result; } else { @@ -627,10 +663,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, // Emit a top-level error message containing the top-level shape in case // of mismatch. int64 total_elements = RecursiveElementCount(actual.shape()); - return_status = InvalidArgument( - "\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, - return_status.error_message().c_str()); + return_status = + InvalidArgument("\nMismatches in shape %s (%d elements):\n%s", + ShapeUtil::HumanString(actual.shape()), + total_elements, return_status.error_message()); } return return_status; } @@ -674,14 +710,14 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.element_type() != actual.element_type()) { return InvalidArgument("element type mismatch, want: %s got %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (ShapeUtil::IsTuple(expected)) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( - "want tuple element count: %lld got tuple element count: %lld", + "want tuple element count: %d got tuple element count: %d", ShapeUtil::TupleElementCount(expected), ShapeUtil::TupleElementCount(actual)); } @@ -695,14 +731,13 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got 
rank of %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (expected.element_type() != actual.element_type()) { - return InvalidArgument( - "mismatch in primitive type %s vs %s", - PrimitiveType_Name(expected.element_type()).c_str(), - PrimitiveType_Name(actual.element_type()).c_str()); + return InvalidArgument("mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()), + PrimitiveType_Name(actual.element_type())); } if (expected.dimensions_size() != actual.dimensions_size()) { return InvalidArgument("want dimensions_size %d got dimensions_size %d", @@ -713,8 +748,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.dimensions(i) != actual.dimensions(i)) { return InvalidArgument( "mismatch in dimension #%d expected: %s actual: %s", i, - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } } } @@ -733,9 +767,8 @@ Status EmitLiteralsInErrorMessage(const Status& result, return result; } return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s", - result.error_message().c_str(), - ToStringTruncated(expected).c_str(), - ToStringTruncated(actual).c_str()); + result.error_message(), ToStringTruncated(expected), + ToStringTruncated(actual)); } } // namespace diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index aef87e46d83fcd927572c82309b677b3479bab1f..dd5b54e4c99998f676419cf98a3da16593338829 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -36,7 +36,6 @@ limitations under the License. 
namespace xla { namespace { -using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -93,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test { Layout layout_r3_dim0minor_; Layout layout_r4_dim0major_; Layout layout_r4_dim0minor_; - std::unique_ptr literal_r4_2x2x3x3_dim0major_; - std::unique_ptr literal_r4_2x2x3x3_dim0minor_; + Literal literal_r4_2x2x3x3_dim0major_; + Literal literal_r4_2x2x3x3_dim0minor_; }; TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - ASSERT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", u32_lit->ToString()); + EXPECT_EQ("42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit.ToString()); - // 3.14 will be truncated to 3.125 in bfloat16 format. + // 3.14 will be rounded to 3.14062 in bfloat16 format. 
auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - ASSERT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -144,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -158,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -172,7 +171,7 @@ f32[2,2] { { 3, 4 } } ))"; - ASSERT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple.ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -188,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -198,7 +197,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { { 9, 10 }, { 11, 12 } } })"; - 
ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, CreateSparse) { @@ -221,10 +220,20 @@ TEST_F(LiteralUtilTest, CreateSparse) { }; std::vector expected_values = {8, 9, 7, 10}; - EXPECT_EQ(literal->sparse_indices()->data(), - ArraySlice(expected_indices.data(), - expected_indices.num_elements())); - EXPECT_EQ(literal->data(), ArraySlice(expected_values)); + EXPECT_EQ(literal.sparse_indices()->data(), + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal.data(), absl::Span(expected_values)); + + // Serialize then deserialize and verify the resulting literal. + TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto, + Literal::CreateFromProto(literal.ToProto())); + + EXPECT_EQ(literal_from_proto.sparse_indices()->data(), + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal_from_proto.data(), + absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -235,8 +244,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { { /*i0=0*/ { /*i1=0*/ @@ -251,13 +260,13 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = literal_r4_2x2x3x3_dim0major_->ToString(); + string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { { /*i0=0*/ { /*i1=0*/ 
@@ -284,7 +293,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, EachCellR2F32) { @@ -295,8 +304,8 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { }); // clang-format on std::vector> seen; - literal->EachCellAsString( - [&seen](ArraySlice indices, const string& value) { + literal.EachCellAsString( + [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -311,14 +320,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) { auto f32_42 = LiteralUtil::CreateR0(42.0); auto f32_42_clone = LiteralUtil::CreateR0(42.0); - EXPECT_EQ(*f32_42, *f32_42); - EXPECT_EQ(*f32_42, *f32_42_clone); + EXPECT_EQ(f32_42, f32_42); + EXPECT_EQ(f32_42, f32_42_clone); auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*f32_42, *f32_123); + EXPECT_NE(f32_42, f32_123); auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_NE(*f32_42, *f64_42); + EXPECT_NE(f32_42, f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { @@ -331,12 +340,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(*matrix, *matrix); - EXPECT_EQ(*matrix, *matrix_clone); - EXPECT_NE(*matrix, *matrix_different); - EXPECT_NE(*matrix, *vector_literal); - EXPECT_NE(*matrix, *scalar); - EXPECT_NE(*matrix, nil); + EXPECT_EQ(matrix, matrix); + EXPECT_EQ(matrix, matrix_clone); + EXPECT_NE(matrix, matrix_different); + EXPECT_NE(matrix, vector_literal); + EXPECT_NE(matrix, scalar); + EXPECT_NE(matrix, nil); EXPECT_EQ(nil, nil); } @@ -345,57 +354,54 @@ TEST_F(LiteralUtilTest, TokenEquality) { auto token1 = LiteralUtil::CreateToken(); auto scalar = LiteralUtil::CreateR0(1.0); - EXPECT_EQ(*token0, *token1); - EXPECT_NE(*token0, *scalar); + EXPECT_EQ(token0, token1); + EXPECT_NE(token0, scalar); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), - *LiteralUtil::MakeTuple({token0.get()})); - 
EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); - EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0}), + LiteralUtil::MakeTuple({&token0})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&token1, &scalar})); + EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&scalar, &token1})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); - colmajor->Set({0, 0}, 1.0); - colmajor->Set({0, 1}, 2.0); - colmajor->Set({1, 0}, 3.0); - colmajor->Set({1, 1}, 4.0); + Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + colmajor.Set({0, 0}, 1.0); + colmajor.Set({0, 1}, 2.0); + colmajor.Set({1, 0}, 3.0); + colmajor.Set({1, 1}, 4.0); - auto rowmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); - rowmajor->Set({0, 0}, 1.0); - rowmajor->Set({0, 1}, 2.0); - rowmajor->Set({1, 0}, 3.0); - rowmajor->Set({1, 1}, 4.0); + Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + rowmajor.Set({0, 0}, 1.0); + rowmajor.Set({0, 1}, 2.0); + rowmajor.Set({1, 0}, 3.0); + rowmajor.Set({1, 1}, 4.0); - EXPECT_EQ(*rowmajor, *colmajor); + EXPECT_EQ(rowmajor, colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. 
auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_EQ(*tuple1, *tuple2); + auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix}); + EXPECT_EQ(tuple1, tuple2); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_NE(*tuple1, *reversed_tuple); + auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar}); + EXPECT_NE(tuple1, reversed_tuple); // Tuple with different value. auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_NE(*tuple1, *different_tuple); + auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix}); + EXPECT_NE(tuple1, different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { @@ -406,162 +412,161 @@ TEST_F(LiteralUtilTest, C64Equality) { // tuple, the other is a clone of the element in the original tuple. auto vector_clone = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - EXPECT_EQ(*vector, *vector_clone); + EXPECT_EQ(vector, vector_clone); auto vector_reversed = LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); - EXPECT_NE(*vector, *vector_reversed); + EXPECT_NE(vector, vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto tuple = LiteralUtil::MakeTuple({&element1, &element1}); // Tuples should always return false for IsAll. - EXPECT_FALSE(tuple->IsAll(0)); - EXPECT_FALSE(tuple->IsAll(1)); + EXPECT_FALSE(tuple.IsAll(0)); + EXPECT_FALSE(tuple.IsAll(1)); } // Verifies that CreateFromShape works for tuples. 
TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto scalar = LiteralUtil::CreateR0(0.0); auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_EQ(*tuple, *x); + auto x = Literal::CreateFromShape(tuple.shape()); + EXPECT_EQ(tuple, x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false).IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. 
auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255).IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0).IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001).IsAll(42)); - EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100}).IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001}).IsAll(100)); - EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}}).IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}}).IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}}).IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, 
{b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) - ->IsAll(-1)); + .IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) - ->IsAllFloat(.5)); + .IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); 
+ EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. 
- EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE( - LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = LiteralUtil::CreateR0(0.0f); auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(scalar_zero->IsZero({})); - EXPECT_FALSE(scalar_one->IsZero({})); + EXPECT_TRUE(scalar_zero.IsZero({})); + EXPECT_FALSE(scalar_one.IsZero({})); auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(array->IsZero({0, 1})); - EXPECT_TRUE(array->IsZero({0, 2})); - EXPECT_TRUE(array->IsZero({1, 1})); - 
EXPECT_FALSE(array->IsZero({1, 2})); + EXPECT_FALSE(array.IsZero({0, 1})); + EXPECT_TRUE(array.IsZero({0, 2})); + EXPECT_TRUE(array.IsZero({1, 1})); + EXPECT_FALSE(array.IsZero({1, 2})); auto complex_zero = LiteralUtil::CreateR0(0.0f); auto complex_nonzero = LiteralUtil::CreateR0(0.5f); - EXPECT_TRUE(complex_zero->IsZero({})); - EXPECT_FALSE(complex_nonzero->IsZero({})); + EXPECT_TRUE(complex_zero.IsZero({})); + EXPECT_FALSE(complex_nonzero.IsZero({})); } template @@ -577,19 +582,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = data->Relayout(layout01); - EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_EQ(*data, *data01); + auto data01 = data.Relayout(layout01); + EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01)); + EXPECT_EQ(data, data01); - auto data10 = data->Relayout(layout10); - EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_EQ(*data, *data10); + auto data10 = data.Relayout(layout10); + EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10)); + EXPECT_EQ(data, data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -607,9 +612,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -627,15 +632,15 @@ TEST_F(LiteralUtilTest, 
ReshapeR4Dim0Minor) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Transpose(/*permutation=*/{}); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Transpose(/*permutation=*/{}); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -647,10 +652,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](ArraySlice indices, float value) { - EXPECT_EQ(value, original->Get( + reshape.EachCell([&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get( {indices[2], indices[3], indices[0], indices[1]})); }); } @@ -659,35 +664,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. 
auto dim0minor_relaid_to_dim0major = - literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); + literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = - literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); + literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->element_count(), 6); - EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_EQ(mat_dim0minor.element_count(), 6); + EXPECT_THAT(mat_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->data(), + auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_); + EXPECT_THAT(relaid_mat_to_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->element_count(), 6); - EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_EQ(mat_dim0major.element_count(), 6); + EXPECT_THAT(mat_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. 
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->data(), + auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_); + EXPECT_THAT(relaid_mat_to_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -708,77 +713,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->element_count(), 12); + EXPECT_EQ(lit_dim0minor.element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->data(), + EXPECT_THAT(lit_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->data(), + EXPECT_THAT(relaid_lit_to_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->element_count(), 12); - EXPECT_THAT(lit_dim0major->data(), + EXPECT_EQ(lit_dim0major.element_count(), 12); + EXPECT_THAT(lit_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. 
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->data(), + auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_); + EXPECT_THAT(relaid_lit_to_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { auto input = LiteralUtil::CreateR0(1); - auto result = input->Slice({}, {}); - EXPECT_EQ(*input, *result); + auto result = input.Slice({}, {}); + EXPECT_EQ(input, result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = input->Slice({3}, {4}); + auto result = input.Slice({3}, {4}); auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR2U32) { auto input_3x4 = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto result = input_3x4.Slice({0, 2}, {2, 4}); auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_EQ(*input_2x3x2, *result); + auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_EQ(input_2x3x2, result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = LiteralUtil::CreateR1({77}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal 
output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { @@ -786,7 +791,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { @@ -794,7 +799,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { bfloat16 h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { @@ -802,7 +807,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { bfloat16 h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { @@ -810,28 +815,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { bfloat16 h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); auto expected = 
LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { @@ -839,7 +844,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { output.PopulateWithValue({4, 2}); auto expected = LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -847,7 +852,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -855,7 +860,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -863,18 +868,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = input->Replicate(3); + auto output = input.Replicate(3); auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_EQ(*output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, CopySliceFrom) { @@ -889,35 +894,35 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](ArraySlice indexes) { - source->Set(indexes, ++seqnr); + auto init_proc = [&](absl::Span 
indexes) { + source.Set(indexes, ++seqnr); return true; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step, init_proc); auto blank = Literal::CreateFromShape(shape); const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](ArraySlice indexes) { + auto check_proc = [&](absl::Span indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = blank->Get(blank_indexes); - matched = (bval != 0 && bval == source->Get(source_indexes)); + auto bval = blank.Get(blank_indexes); + matched = (bval != 0 && bval == source.Get(source_indexes)); return matched; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step, check_proc); EXPECT_TRUE(matched); } @@ -926,14 +931,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = LiteralUtil::CreateR0(0); auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(zero->CopyFrom(*nine)); - EXPECT_EQ(*zero, *nine); + TF_EXPECT_OK(zero.CopyFrom(nine)); + EXPECT_EQ(zero, nine); auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); - EXPECT_EQ(zero->Get({}), 17); - 
TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); - EXPECT_EQ(vect->Get({4}), 17); + TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {})); + EXPECT_EQ(zero.Get({}), 17); + TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {})); + EXPECT_EQ(vect.Get({4}), 17); } TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { @@ -946,17 +951,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); - EXPECT_EQ(*nine, *const_nine); + TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0})); + EXPECT_EQ(nine, const_nine); } { // Copy 0 element to destination with zero elements. - const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); - EXPECT_EQ(*empty, *const_empty); + TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0})); + EXPECT_EQ(empty, const_empty); } } @@ -970,76 +975,77 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { TEST_F(LiteralUtilTest, CopyFromArrays) { auto scalar_42 = LiteralUtil::CreateR0(42.0); auto scalar_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*scalar_42, *scalar_123); - TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*scalar_42, *scalar_123); - EXPECT_EQ(scalar_42->Get({}), 123.0f); + EXPECT_NE(scalar_42, scalar_123); + TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(scalar_42, scalar_123); + EXPECT_EQ(scalar_42.Get({}), 123.0f); auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); - EXPECT_NE(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); - TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, 
/*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); + EXPECT_NE(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 5.0f); } TEST_F(LiteralUtilTest, CopyFromTuples) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {matrix.get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get()}); + Literal inner_elements[] = {LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0})}; + Literal inner_tuple = LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}); + Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(-5).get(), - LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); + Literal int32_minus5 = LiteralUtil::CreateR0(-5); + Literal double_2_4 = LiteralUtil::CreateR1({2.0, 4.0}); + Literal tuple = + LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal}); - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 44.0); // Overwrite the inner tuple element of nested_tuple with the contents of // 'tuple'. 
- TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{})); + TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), - LiteralUtil::CreateR0(4).get()}); + Literal elements[] = {LiteralUtil::CreateR0(-2), + LiteralUtil::CreateR0(4)}; + Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), 4); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), 4); // Copy from one element to the other. 
- TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{0})); + TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), -2); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), -2); } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto vector = LiteralUtil::CreateR1({5.0, 7.0}); - Status status = matrix->CopyFrom(*vector); + Status status = matrix.CopyFrom(vector); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); } @@ -1047,9 +1053,8 @@ TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); - Literal* l1 = m1.get(); - const char* d1 = reinterpret_cast(l1->data().data()); + Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + const char* d1 = reinterpret_cast(m1.data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -1062,8 +1067,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l2 = m2.get(); - const char* d2 = reinterpret_cast(l2->data().data()); + const char* d2 = reinterpret_cast(m2.data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -1092,25 +1096,25 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + Literal literal(shape); + auto generator = [&](absl::Span 
indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->Populate(generator)); + TF_EXPECT_OK(literal.Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { - auto value = literal->Get(indexes); + auto check_function = [&](absl::Span indexes) { + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1134,25 +1138,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + Literal literal(shape); + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->PopulateParallel(generator)); + TF_EXPECT_OK(literal.PopulateParallel(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { - auto value = literal->Get(indexes); + auto check_function = [&](absl::Span indexes) { + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1171,10 +1175,9 @@ TEST_F(LiteralUtilTest, ConvertR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->Convert(U32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32)); - EXPECT_EQ(*expected, *converted); + EXPECT_EQ(expected, converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -1246,69 +1249,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); // clang-format on - std::unique_ptr conv; + Literal conv; - conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u32); + conv = s8.Convert(U32).ConsumeValueOrDie(); + EXPECT_EQ(conv, u32); - conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = s8.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u64); + conv = s8.Convert(U64).ConsumeValueOrDie(); + EXPECT_EQ(conv, u64); - conv = s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s64); + conv = 
s8.Convert(S64).ConsumeValueOrDie(); + EXPECT_EQ(conv, s64); - conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *pred); + conv = s8.Convert(PRED).ConsumeValueOrDie(); + EXPECT_EQ(conv, pred); - conv = bf16->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = bf16.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = bf16->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = bf16.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *int32_pred); + conv = pred.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, int32_pred); - conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f32.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f64.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = s32.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f64.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = s32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = u32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = s32.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - conv = f16->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = f16.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - 
EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(S32).status().code(), + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1318,13 +1317,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) { tensorflow::bit_cast(100.f), 0xbeef}); auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->BitcastConvert(F32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); - Status status = literal->BitcastConvert(F64).status(); + Status status = literal.BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE( absl::StrContains(status.error_message(), "bit widths are different")); @@ -1342,11 +1340,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { p.add_preds((i % 2) == (len % 2)); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - ASSERT_EQ(len, literal->data().size()); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal.data().size()); int i = 0; - for (bool value : literal->data()) { + for (bool value : literal.data()) { EXPECT_EQ((i % 2) == (len % 
2), value); ++i; } @@ -1359,11 +1356,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h2(2.0f); auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l = m.get(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->data().size()); + EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape())); + EXPECT_EQ(4, m.data().size()); - LiteralProto p = l->ToProto(); + LiteralProto p = m.ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); @@ -1390,56 +1386,53 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - auto r = literal->data(); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); ASSERT_EQ(4, r.size()); - ASSERT_EQ(h1, r[0]); - ASSERT_EQ(h2, r[1]); - ASSERT_EQ(h2, r[2]); - ASSERT_EQ(h1, r[3]); + EXPECT_EQ(h1, r[0]); + EXPECT_EQ(h2, r[1]); + EXPECT_EQ(h2, r[2]); + EXPECT_EQ(h1, r[3]); } TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); - EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); - EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(scalar, {}), scalar); + EXPECT_EQ(LiteralSlice(matrix, {}), matrix); + EXPECT_EQ(LiteralSlice(tuple, {}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple); EXPECT_EQ(LiteralSlice(nil, {}), nil); - 
EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(tuple, {0}), scalar); + EXPECT_EQ(LiteralSlice(tuple, {1}), matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix); + EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar); } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. 
- const auto nested_tuple_view = LiteralSlice(*nested_tuple); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 1.0f); + const auto nested_tuple_view = LiteralSlice(nested_tuple); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); - nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 555.0f); + nested_tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 555.0f); @@ -1448,14 +1441,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); - const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1498,9 +1491,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { } TEST_F(LiteralUtilTest, LiteralMove) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - Literal literal(std::move(*matrix)); + Literal 
matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(matrix)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1512,17 +1504,21 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get(), - &nil_literal}); - - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); - std::vector elements = nested_tuple->DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + Literal inner_elements[] = { + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0}), + }; + Literal tuple_elements[] = { + LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}), + }; + Literal nested_tuple = LiteralUtil::MakeTuple( + {&tuple_elements[0], &tuple_elements[1], &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + std::vector elements = nested_tuple.DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1553,15 +1549,15 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); - elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); - elements.push_back(std::move(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get()}) - - )); - - Literal literal = Literal::MoveIntoTuple(&elements); + elements.push_back(LiteralUtil::CreateR0(1.0)); + elements.push_back(LiteralUtil::CreateR1({4, 8})); + std::vector inner_elements; + inner_elements.push_back(LiteralUtil::CreateR0(42)); + 
inner_elements.push_back(LiteralUtil::CreateR1({23.0, 44.0})); + elements.push_back( + LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); + + Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); @@ -1580,16 +1576,15 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); - ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); + EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } TEST_F(LiteralUtilTest, LiteralMoveAssignment) { Literal literal; EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - literal = std::move(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(matrix); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1600,9 +1595,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { } TEST_F(LiteralUtilTest, LiteralSliceCopy) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralSlice(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralSlice(matrix); LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); @@ -1612,45 +1606,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42.0).get(), - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); - tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 
-5.0); - - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), - 3.0); - tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + Literal elements[] = { + LiteralUtil::CreateR0(42.0), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + }; + auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); + tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0); + tuple.Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), -4.0); } TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { // Literals constructed using CreateFromShape should be zero initialized. - std::unique_ptr scalar_f32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); - EXPECT_EQ(scalar_f32->Get({}), 0.0); - EXPECT_TRUE(scalar_f32->IsAll(0)); - - std::unique_ptr vector_s32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); - EXPECT_EQ(vector_s32->Get({0}), 0); - EXPECT_EQ(vector_s32->Get({1}), 0); - EXPECT_EQ(vector_s32->Get({2}), 0); - EXPECT_TRUE(vector_s32->IsAll(0)); - - std::unique_ptr tuple = - Literal::CreateFromShape(ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); - - EXPECT_EQ(tuple->Get({}, {0}), 0.0); - EXPECT_EQ(tuple->Get({0}, {1}), false); - EXPECT_EQ(tuple->Get({1}, {1}), false); - EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); + Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32.Get({}), 0.0); + EXPECT_TRUE(scalar_f32.IsAll(0)); 
+ + Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32.Get({0}), 0); + EXPECT_EQ(vector_s32.Get({1}), 0); + EXPECT_EQ(vector_s32.Get({2}), 0); + EXPECT_TRUE(vector_s32.IsAll(0)); + + Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + + EXPECT_EQ(tuple.Get({}, {0}), 0.0); + EXPECT_EQ(tuple.Get({0}, {1}), false); + EXPECT_EQ(tuple.Get({1}, {1}), false); + EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1658,6 +1650,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto one_f32 = LiteralUtil::CreateR0(1.0); auto two_f32 = LiteralUtil::CreateR0(2.0); auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); @@ -1666,25 +1659,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto tuple = LiteralUtil::MakeTuple( - {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + {&one_f32, &vector_half, &matrix_pred, &matrix_pred}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + auto nested_tuple = + LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { - return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + return Literal::CreateFromProto(literal.ToProto()).ValueOrDie(); }; - 
EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); - EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); - EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); - EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); - EXPECT_EQ(*tuple, to_from_proto(*tuple)); - EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(one_f32, to_from_proto(one_f32)); + EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); + EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); + EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); + EXPECT_EQ(tuple, to_from_proto(tuple)); + EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); - EXPECT_NE(*one_f32, *two_f32); - EXPECT_NE(*one_f32, to_from_proto(*two_f32)); + EXPECT_NE(one_f32, two_f32); + EXPECT_NE(one_f32, to_from_proto(two_f32)); } TEST_F(LiteralUtilTest, InvalidProtoNoValues) { @@ -1693,7 +1688,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) { *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1705,7 +1700,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); } TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { @@ -1717,7 +1712,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), 
HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1730,7 +1725,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { proto.add_f32s(3.0); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 84 elements in LiteralProto")); } @@ -1743,7 +1738,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { proto.add_s32s(100); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 elements in LiteralProto")); } @@ -1758,7 +1753,7 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); } TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { @@ -1774,7 +1769,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { @@ -1797,17 +1792,17 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, SortSparseElements) { auto literal = LiteralUtil::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); - literal->AppendSparseElement({2, 3, 4}, 2.0); - 
literal->AppendSparseElement({3, 4, 5}, 3.0); - literal->AppendSparseElement({1, 2, 3}, 1.0); - literal->SortSparseElements(); - ASSERT_EQ(literal->ToString(false), + literal.AppendSparseElement({2, 3, 4}, 2.0); + literal.AppendSparseElement({3, 4, 5}, 3.0); + literal.AppendSparseElement({1, 2, 3}, 1.0); + literal.SortSparseElements(); + EXPECT_EQ(literal.ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1815,59 +1810,56 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { std::vector dimensions = {10, 10, 10}; SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); - ASSERT_EQ( + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), "false"); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) - ->GetSparseElementAsString(1), + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) + .GetSparseElementAsString(1), absl::StrCat(int64{2})); - ASSERT_EQ( + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(double{2.0})); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(static_cast(half{2.0}))); - ASSERT_EQ(LiteralUtil::CreateSparse( + EXPECT_EQ(LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - 
/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = LiteralUtil::CreateR0(9); + Literal literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), - /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 95d93acfe8a65dd6d19270fc1a496680585c984d..0cb1ae35f4ad31f091063d78ed32c1463be8ee0a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -33,7 +33,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" @@ -46,7 +45,7 @@ using absl::StrCat; // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template -std::unique_ptr ConvertType(LiteralSlice literal) { +Literal ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -57,7 +56,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. @@ -68,14 +67,14 @@ std::unique_ptr ConvertType(LiteralSlice literal) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); - auto dest = result->data(shape_index); + auto dest = result.data(shape_index); for (int64 i = 0; i < src.size(); ++i) { dest[i] = static_cast(src[i]); } } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); + TF_CHECK_OK(result.CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); } } }); @@ -84,54 +83,52 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { +/* static */ Literal LiteralUtil::CreateFromDimensions( + PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); 
} -/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( +/* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( +/* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -/* static */ std::unique_ptr LiteralUtil::CreateToken() { - return absl::make_unique(ShapeUtil::MakeTokenShape()); +/* static */ Literal LiteralUtil::CreateToken() { + return Literal(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case C64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -147,30 +144,29 @@ std::unique_ptr 
ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case C64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -186,42 +182,36 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case 
S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case F32: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -234,40 +224,34 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + 
return LiteralUtil::CreateR0(std::numeric_limits::max()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case F32: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -277,34 +261,31 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ Literal LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = absl::make_unique( + Literal literal( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( - absl::string_view value) { - auto literal = absl::make_unique( - ShapeUtil::MakeShape(U8, 
{static_cast(value.size())})); +/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) { + Literal literal(ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { - literal->Set({i}, value[i]); + literal.Set({i}, value[i]); } return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal) { +/* static */ Literal LiteralUtil::ReshapeSlice( + absl::Span new_dimensions, + absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; for (int64 i = 0; i < new_dimensions.size(); ++i) { new_num_elements *= new_dimensions[i]; @@ -312,13 +293,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = absl::make_unique( + Literal new_literal( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used // solely for converting linear address to multi-dimensional addresses when // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); + Shape shape_with_layout = new_literal.shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Copy data into new literal, element-by-element. 
@@ -329,40 +310,40 @@ std::unique_ptr ConvertType(LiteralSlice literal) { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " @@ -379,103 +360,88 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 8 bit types. 
case S8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 16 bit types. case BF16: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 32 bit types. case F32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 64 bit types. 
case C64: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); } } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( - tensorflow::gtl::ArraySlice elements) { +/* static */ Literal LiteralUtil::MakeTuple( + absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements) { +/* static */ Literal LiteralUtil::MakeTupleFromSlices( + absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr 
LiteralUtil::MakeTupleOwned( - std::vector> elements) { +/* static */ Literal LiteralUtil::MakeTupleOwned( + std::vector elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_shapes.push_back(element->shape()); + element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( - literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } return literal; } /* static */ string LiteralUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return StrCat("{", absl::StrJoin(multi_index, ","), "}"); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 3d28c070f29052f2686cf605e068deadd998719c..2b181621ed92be8952ccec19e0d4229c494b9f47 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -69,37 +69,34 @@ class LiteralUtil { // The variants not ending with WithLayout use the default XLA layout for the // literal's linear representation in memory. 
template - static std::unique_ptr CreateR0(NativeT value); + static Literal CreateR0(NativeT value); template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values); - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values); + static Literal CreateR1(absl::Span values); + static Literal CreateR1(const tensorflow::core::Bitmap& values); template - static std::unique_ptr CreateR2( + static Literal CreateR2( std::initializer_list> values); template - static std::unique_ptr CreateR2WithLayout( + static Literal CreateR2WithLayout( std::initializer_list> values, const Layout& layout); template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values); + static Literal CreateR3(std::initializer_list< + std::initializer_list>> + values); template - static std::unique_ptr CreateR3WithLayout( + static Literal CreateR3WithLayout( std::initializer_list< std::initializer_list>> values, const Layout& layout); template - static std::unique_ptr CreateR4( + static Literal CreateR4( std::initializer_list>>> values); template - static std::unique_ptr CreateR4WithLayout( + static Literal CreateR4WithLayout( std::initializer_list>>> values, @@ -140,9 +137,10 @@ class LiteralUtil { // [9, 10, 11]: 4.0 // template - static std::unique_ptr CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort = true); + static Literal CreateSparse(absl::Span dimensions, + SparseIndexArray indices, + absl::Span values, + bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -156,132 +154,120 @@ class LiteralUtil { static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. 
template - static std::unique_ptr CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value); + static Literal CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template - static std::unique_ptr CreateFromArray(const Array& values); + static Literal CreateFromArray(const Array& values); template - static std::unique_ptr CreateFromArrayWithLayout( - const Array& values, const Layout& layout); + static Literal CreateFromArrayWithLayout(const Array& values, + const Layout& layout); template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values); + static Literal CreateR2FromArray2D(const Array2D& values); template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); + static Literal CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values); + static Literal CreateR3FromArray3D(const Array3D& values); template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); + static Literal CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values); + static Literal CreateR4FromArray4D(const Array4D& values); template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); + static Literal CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Creates a new vector of U8s literal value from a string. 
- static std::unique_ptr CreateR1U8(absl::string_view value); + static Literal CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols); + static Literal CreateR2F32Linspace(float from, float to, int64 rows, + int64 cols); // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". template - static std::unique_ptr CreateR3Projected( + static Literal CreateR3Projected( std::initializer_list> values, int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. template - static std::unique_ptr CreateR4Projected( + static Literal CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template - static std::unique_ptr MakeIdentityR2(int64 size); + static Literal MakeIdentityR2(int64 size); // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. - static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements); + static Literal MakeTuple(absl::Span elements); - static std::unique_ptr MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements); + static Literal MakeTupleFromSlices(absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // - // std::vector> elements = ...; + // std::vector elements = ...; // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. 
- static std::unique_ptr MakeTupleOwned( - std::vector> elements); + static Literal MakeTupleOwned(std::vector elements); - // This overload lets you pass a braced list of unique_ptrs to + // This overload lets you pass a braced list of Literals to // MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // - // Simply relying on the MakeTupleOwned(std::vector>) + // Simply relying on the MakeTupleOwned(std::vector) // overload doesn't work because std::initializer_list's elements are always // const. // - // The arguments to this function must all be unique_ptr. + // The arguments to this function must all be Literal. template - static std::unique_ptr MakeTupleOwned( - std::unique_ptr... elements) { - std::array, sizeof...(Ts)> arr{ - std::move(elements)...}; - std::vector> v; + static Literal MakeTupleOwned(Ts... elements) { + std::array arr{std::move(elements)...}; + std::vector v; v.insert(v.begin(), std::make_move_iterator(arr.begin()), std::make_move_iterator(arr.end())); return MakeTupleOwned(std::move(v)); } // Create a constant token literal. Token types have no value. - static std::unique_ptr CreateToken(); + static Literal CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); + static Literal CreateFromDimensions(PrimitiveType primitive_type, + absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. 
- static std::unique_ptr ConvertBF16ToF32( - const LiteralSlice& bf16_literal); + static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16( - const LiteralSlice& f32_literal); + static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major // layout order. - static std::unique_ptr ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal); + static Literal ReshapeSlice(absl::Span new_dimensions, + absl::Span minor_to_major, + const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. 
@@ -289,9 +275,9 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( + static StatusOr CreateRandomLiteral( const Shape& shape, - const std::function)>& generator); + const std::function)>& generator); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -300,8 +286,8 @@ class LiteralUtil { template < PrimitiveType type, typename E, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, E* engine, + T mean, T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -310,8 +296,8 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. @@ -319,51 +305,49 @@ class LiteralUtil { // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. 
- static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); + static string MultiIndexAsString(absl::Span multi_index); }; std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = absl::make_unique(ShapeUtil::MakeShape( +/* static */ Literal LiteralUtil::CreateR0(NativeT value) { + Literal literal(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); - literal->Set({}, value); + literal.Set({}, value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( - tensorflow::gtl::ArraySlice values) { - auto literal = absl::make_unique( +/* static */ Literal LiteralUtil::CreateR1(absl::Span values) { + Literal literal( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ Literal LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, AsInt64Slice(layout.minor_to_major()))); - literal->PopulateR2(values); + literal.PopulateR2(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ Literal LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ Literal LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -388,14 +372,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ 
Literal LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ Literal LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -426,23 +410,22 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort) { +/* static */ Literal LiteralUtil::CreateSparse( + absl::Span dimensions, SparseIndexArray indices, + absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); - literal->PopulateSparse(indices, values, sort); + Literal literal(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal.PopulateSparse(indices, values, sort); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ Literal LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -450,50 +433,48 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( +/* static */ Literal LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); - literal->PopulateFromArray(values); + literal.PopulateFromArray(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArray( +/* static */ Literal 
LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ Literal LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ Literal LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ Literal LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -518,7 +499,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ Literal LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -546,21 +527,20 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ Literal LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ Literal 
LiteralUtil::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -569,46 +549,39 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); - literal->PopulateWithValue(value); +/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal.PopulateWithValue(value); return literal; } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral( +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, - const std::function)>& generator) { + const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = absl::make_unique(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); + Literal literal(shape); + TF_RETURN_IF_ERROR(literal.Populate( + [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { using NativeT = typename 
primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); + shape, + [&](absl::Span /*indexes*/) { return generator(*engine); }); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 3c74e070da529b7f1431e01fbaf31932f582db44..fcff48b6b18ba115a67f3141a9aea4ca461be55d 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -60,7 +60,7 @@ MaybeFind(const Collection& collection, if (it == collection.end()) { std::ostringstream os; os << key; - return NotFound("key not found: %s", os.str().c_str()); + return NotFound("key not found: %s", os.str()); } return {it->second}; } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 2f22e02c3edc1979d91efdb4b9c8697e5301a47f..4eab4fa4290c270697c00be20840cf4e85459183 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -264,8 +264,7 @@ string MetricTableReport::MetricString(double metric) { } string MetricTableReport::MetricPercent(double metric) { - return tensorflow::strings::Printf("%5.2f%%", - metric / expected_metric_sum_ * 100.0); + return absl::StrFormat("%5.2f%%", metric / expected_metric_sum_ * 100.0); } } // namespace xla diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 012df875519c5ddec498507a56da40253e5e1da6..0f86f9f35e105713aa3072a9ebf572d33d35d66d 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file) PackedLiteralReader::~PackedLiteralReader() { delete file_; } -StatusOr> PackedLiteralReader::Read( - const Shape& shape, const Layout* layout) { +StatusOr PackedLiteralReader::Read(const Shape& shape, + const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? 
"" : layout->ShortDebugString()); @@ -54,17 +54,17 @@ StatusOr> PackedLiteralReader::Read( if (shape.element_type() != F32) { return Unimplemented( "not yet implemented element type for packed literal reading: %s", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } - auto result = absl::make_unique(literal_shape); - result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + Literal result(literal_shape); + result.PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - tensorflow::gtl::ArraySlice field = result->data(); - char* data = tensorflow::bit_cast(field.data()); + absl::Span field = result.data(); + char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; // non-absl OK + absl::string_view sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. char single_byte[1]; - tensorflow::StringPiece sp; // non-absl OK + absl::string_view sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 98dccaa9a246520bf60217b96d67a13a24c34b4a..d6d2ff1521bab341b166c4f5c1dc0917e28573d8 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -41,8 +41,7 @@ class PackedLiteralReader { // // Layout is optional. If it is not provided, no layout is set on the literal // that is produced. - StatusOr> Read(const Shape& shape, - const Layout* layout = nullptr); + StatusOr Read(const Shape& shape, const Layout* layout = nullptr); // Returns whether the input file has been fully exhausted; i.e. 
all available // packed literals have been read and we're at the end of the file. diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index 787725e884c810fd724ab88ad7d4beaf3e0a6cc7..b507a2ef79f1d7e9ae632744675dddf574490805 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) { return safe_file_name; } +std::pair>*> +GetDirectoryExpanders() { + static auto* mutex = new tensorflow::mutex; + static auto* singleton = new std::vector>; + return {mutex, singleton}; +} + +// Runs all the directory expanders over x and returns the result. +string Expand(string x) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + for (const auto& f : *pair.second) { + x = f(x); + } + return x; +} + } // namespace Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string expanded_dir = Expand(directory); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); + const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } +void RegisterDirectoryExpander(const std::function& expander) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + pair.second->push_back(expander); +} + } // namespace 
protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 3667621367c7639c40ff17aee7b77305d4d34e33..f22fc8b8499dd4a5329276040331a2ed9e89bea9 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); +// Registers a function that may either expand a dirpath or forward the original +// dirpath along as-is. +void RegisterDirectoryExpander(const std::function& expander); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 2d8fe434b0d774615f94fe5d111390a9a756eb94..f0d84646b9f01ad3ad209073f13b7b3ec21635d1 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -40,6 +40,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -61,6 +63,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 00e36c3c86a8b46b8479ac8245405459c3cfdd81..92df404b8ec0aed4899906877a4dd41102bdf7a0 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, return client->TransferToInfeedLocal(literal, device_ordinal); } -StatusOr> TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number) { 
+StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; LocalClient* client = GetOrCreateLocalClient(); @@ -141,9 +141,8 @@ StatusOr LocalShapedBuffer::FromLiteral( LocalClient* client = GetOrCreateLocalClient(); StatusOr buf = [&] { if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, relaid); } return ToBuffer(client, /*device_ordinal=*/0, argument); }(); @@ -151,7 +150,7 @@ StatusOr LocalShapedBuffer::FromLiteral( return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -StatusOr> LocalShapedBuffer::ToLiteral() const { +StatusOr LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation( std::unique_ptr executable) : executable_(std::move(executable)) {} -StatusOr> CompiledLocalComputation::Execute( +StatusOr CompiledLocalComputation::Execute( const std::vector& arguments, const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); @@ -169,7 +168,7 @@ StatusOr> CompiledLocalComputation::Execute( // Each replica populates a StatusOr result, but only replica zero actually // retrieves its literal value. 
- std::vector>> results(GetReplicaCount()); + std::vector> results(GetReplicaCount()); { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", GetReplicaCount()); @@ -198,9 +197,8 @@ StatusOr> CompiledLocalComputation::Execute( StatusOr pushed; if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, relaid); } else { pushed = ToBuffer(client, device_ordinal, argument); } @@ -251,7 +249,7 @@ StatusOr> CompiledLocalComputation::Execute( return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", - replica, statusor.status().ToString().c_str()); + replica, statusor.status().ToString()); } } @@ -259,7 +257,7 @@ StatusOr> CompiledLocalComputation::Execute( } LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles) { + absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); std::vector argument_buffers; @@ -369,8 +367,7 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { } LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { + const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } @@ -380,14 +377,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp 
LocalComputationBuilder::Collapse( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } @@ -395,10 +392,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { return xla::CrossReplicaSum(operand.op()); } -LocalOp LocalComputationBuilder::Slice( - const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } @@ -411,7 +408,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } @@ -421,8 +418,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice( return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, int64 dimension) { +LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -433,18 +430,16 @@ LocalOp LocalComputationBuilder::ConcatInDim( LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter) { + 
absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { +LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -471,13 +466,14 @@ LocalOp LocalComputationBuilder::DotGeneral( LocalOp LocalComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, - lhs_dilation, rhs_dilation, dimension_numbers); + lhs_dilation, rhs_dilation, dimension_numbers, + feature_group_count); } LocalOp LocalComputationBuilder::ConvertElementType( @@ -490,9 +486,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType( return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { +LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -502,19 +497,18 @@ LocalOp LocalComputationBuilder::Call( } LocalOp LocalComputationBuilder::Transpose( - 
const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Map(absl::Span operands, + const LocalComputation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -528,7 +522,7 @@ LocalOp LocalComputationBuilder::Map( LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } @@ -536,12 +530,15 @@ LocalOp LocalComputationBuilder::Reduce( LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); } LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, 
@@ -575,13 +572,13 @@ StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { } LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { - return xla::Sort(operand.op(), absl::nullopt, dimension); + return xla::Sort(operand.op(), {}, dimension); } LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension) { - return xla::Sort(keys.op(), values.op(), dimension); + return xla::Sort(keys.op(), {values.op()}, dimension); } StatusOr LocalComputationBuilder::BuildConstantSubGraph( @@ -599,10 +596,10 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD_UNOP(method_name) \ _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) -#define _FORWARD_BINOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + absl::Span broadcast_dimensions), \ (lhs.op(), rhs.op(), broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ @@ -696,8 +693,7 @@ StatusOr DestructureLocalShapedBufferTuple( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", ShapeUtil::HumanString( - local_shaped_buffer->shaped_buffer()->on_device_shape()) - .c_str()); + local_shaped_buffer->shaped_buffer()->on_device_shape())); } DeviceMemoryAllocator* allocator = diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index d9543b958dc40e092221b0276e2b1317bbcf499f..43332e0abd410c08dc5a40f7de39dbc96d34a72c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace swig { @@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); // Transfers a literal of the given shape from the outfeed of the given replica. // // The replica number is resolved to an appropriate device ordinal. -StatusOr > TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number); +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number); // Wraps a ScopedShapedBuffer produced by copying a literal "to // device," i.e. copying a literal to a scoped buffer via the local @@ -65,7 +65,7 @@ class LocalShapedBuffer { LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - StatusOr > ToLiteral() const; + StatusOr ToLiteral() const; // Transfers ownership of the encapsulated ShapedBuffer to the caller, // analogous to std::unique_ptr::release(). @@ -117,12 +117,12 @@ class CompiledLocalComputation { // with optionally-specified argument layouts. The literals will be // re-laid out according to the corresponding elements of // shapes_with_layout. 
- StatusOr > Execute( + StatusOr Execute( const std::vector& arguments, const std::vector >& shapes_with_layout); LocalShapedBuffer* ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles); + absl::Span argument_handles); private: std::unique_ptr executable_; @@ -199,46 +199,41 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); LocalOp Broadcast(const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config); - LocalOp Reshape(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, absl::Span dimensions, + absl::Span new_sizes); - LocalOp Collapse(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); LocalOp CrossReplicaSum(const LocalOp& operand); - LocalOp Slice(const LocalOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); LocalOp SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices); - LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice 
window_strides, - tensorflow::gtl::ArraySlice > padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter); - LocalOp Tuple(tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(absl::Span elements); LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); @@ -249,11 +244,12 @@ class LocalComputationBuilder { LocalOp ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span > padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); @@ -262,28 +258,29 @@ class LocalComputationBuilder { PrimitiveType new_element_type); LocalOp Call(const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); LocalOp Transpose(const LocalOp& operand, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); - LocalOp Rev(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, absl::Span dimensions); - LocalOp Map(tensorflow::gtl::ArraySlice operands, + LocalOp Map(absl::Span operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); LocalOp 
ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape); @@ -316,7 +313,7 @@ class LocalComputationBuilder { #define _FORWARD_BINOP(method_name) \ _FORWARD(method_name, LocalOp, \ (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) + absl::Span broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ _FORWARD(method_name, LocalOp, \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 08dccb3ee18606965b39bbcb79a89a0478afa790..521490e76c138553c5cc6895412eadb35a939881 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,15 +22,15 @@ limitations under the License. 
// // C++ Python // -------------------------------------+--------------------------------------- -// ArraySlice <- sequence of int -// ArraySlice <- sequence of LocalOp +// Span <- sequence of int +// Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int -// ArraySlice> <- sequence of int pairs +// Span> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto @@ -109,11 +109,12 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "third_party/absl/strings/str_cat.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" @@ -155,8 +156,8 @@ bool HandleStringAttribute(PyObject* o, return true; // The attribute is None, which we consider ok. } if (!PyString_Check(attr)) { - string message = tensorflow::strings::Printf("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr).c_str()); + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. 
@@ -215,9 +216,9 @@ tensorflow::ImportNumpy(); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if ($1.ok()) { - std::unique_ptr value = $1.ConsumeValueOrDie(); + Literal value = $1.ConsumeValueOrDie(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -266,9 +267,9 @@ tensorflow::ImportNumpy(); $result = Py_None; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -298,9 +299,9 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice( +%typemap(in) absl::Span( std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -322,7 +323,7 @@ tensorflow::ImportNumpy(); // LocalShapedBuffer* -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -345,25 +346,25 @@ tensorflow::ImportNumpy(); // Literal -%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { +%typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); SWIG_fail; } - $1 = literal_status.ValueOrDie().get(); + $1 = &literal_status.ValueOrDie(); } -%typemap(out) std::unique_ptr { +%typemap(out) Literal { $result = numpy::PyObjectFromXlaLiteral(*$1); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); + $result = 
numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); } %typemap(in) const std::vector& (std::vector temps) { @@ -374,13 +375,13 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); SWIG_fail; } - temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + temps.push_back(literal_status.ConsumeValueOrDie()); Py_DECREF(o); } $1 = &temps; @@ -495,9 +496,9 @@ tensorflow::ImportNumpy(); $1 = static_cast(value); } -// ArraySlice> +// Span> -%typemap(in) tensorflow::gtl::ArraySlice > +%typemap(in) absl::Span > (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index f2f99c1745900fdb4ca5fc8b14d65c67de1dc135..b0aa024c7474cf8e6934432b2f364be464714999 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -150,9 +151,7 @@ static int NumpyTypenum(PyObject* o) { // // NOTE: this is an internal helper for conversion to a C++, and so decrefs r. 
static string ExtractStringAndDecref(PyObject* r) { - auto error = [r] { - return tensorflow::strings::Printf("", r); - }; + auto error = [r] { return absl::StrFormat("", r); }; if (r == nullptr) { return error(); } @@ -369,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } } -StatusOr> XlaLiteralFromPyObject(PyObject* o) { +StatusOr XlaLiteralFromPyObject(PyObject* o) { if (PyTuple_Check(o)) { int num_elements = PyTuple_Size(o); - std::vector> elements; + std::vector elements; elements.reserve(num_elements); for (int i = 0; i < num_elements; i++) { PyObject* element = PyTuple_GetItem(o, i); @@ -390,8 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { int np_type = PyArray_TYPE(py_array); auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); - TF_RETURN_IF_ERROR( - CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); return std::move(literal); } else { return InvalidArgument( diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index a67c93a4fb7413f9bbcb9afd92c36fd118836e1f..40ff2d9ad214cc4dcad42234fa296834cbc92882 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -25,9 +25,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/python/lib/core/numpy.h" namespace xla { @@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // To avoid transferring ownership of the data buffers that underlie // PyArrays and XLA literals, this function makes deep copies of all // array data. 
-StatusOr > XlaLiteralFromPyObject(PyObject* o); +StatusOr XlaLiteralFromPyObject(PyObject* o); // The following functions copy array data from the buffers underlying Numpy // ndarrays into those underlying XLA literals, and vice versa. diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index fa4366ff0789a3d05c26479a746a18dfcf7e902b..f8197488fb3bacb312cc7fbf149b773851992b8a 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -995,7 +995,30 @@ class ComputationBuilder(object): window_strides) return self._client.ReduceWindowWithGeneralPadding( operand, init_value, computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads) + window_dimensions, window_strides, (), (), pads) + + def ReduceWindowWithGeneralPadding( + self, operand, init_value, computation_to_apply, window_dimensions, + window_strides, base_dilations, window_dilations, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + base_dilations: dilations for the base (sequence of integers). + window_dilations: dilations for window (sequence of integers). + padding: length-N array-like of pairs of integers of (low, high) padding. + + Returns: + A LocalOp representing the added ReduceWindow op. + """ + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, base_dilations, window_dilations, + padding) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. 
@@ -1109,7 +1132,7 @@ class ComputationBuilder(object): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) return self._client.DotGeneral(lhs, rhs, dimension_numbers) - def Conv(self, lhs, rhs, window_strides, padding): + def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): """Enqueues a Conv operation onto the computation. Args: @@ -1117,6 +1140,7 @@ class ComputationBuilder(object): rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the Conv operation. """ @@ -1125,10 +1149,11 @@ class ComputationBuilder(object): self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), - (), dimension_numbers) + (), dimension_numbers, + feature_group_count) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation): + lhs_dilation, rhs_dilation, feature_group_count=1): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: @@ -1138,6 +1163,7 @@ class ComputationBuilder(object): padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. rhs_dilation: length-N array-like of dilation factors. + feature_group_count: number of feature groups for grouped convolution. Returns: A ComputationdataHandle representing the added ConvWithGeneralPadding op. 
@@ -1145,7 +1171,8 @@ class ComputationBuilder(object): dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1163,7 +1190,8 @@ class ComputationBuilder(object): return dimension_numbers def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers): + rhs_dilation, dimension_numbers, + feature_group_count=1): """Enqueues a ConvGeneralDilated operation onto the computation. Args: @@ -1190,6 +1218,7 @@ class ComputationBuilder(object): labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character appearing in rhs_spec that is not 'I' or 'O'. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the ConvGenralDilated operation. 
""" @@ -1215,7 +1244,8 @@ class ComputationBuilder(object): key=lambda i: rhs_spec.index(out_spec[i]))) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def Sort(self, operand, dimension=-1): """Enqueues a sort operation onto the computation.""" diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fd98e19457f61aade947aa354d2e415148d127f6..82103f03132e45ff822ce1ebcc2be47b24f5869f 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + feature_group_count = 2 + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]], + [[0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 3de7ee2bc8c936680735102607436af77a17769c..ceb5e74db7c3b9305e9d77068df9ae0a3690af8a 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -108,17 +108,15 @@ 
ReferenceUtil::ConvArray3DGeneralDimensionsDilated( // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. Array4D a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); - a4dlhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); - }); + a4dlhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); + }); Array4D a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); - a4drhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); - }); + a4drhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); + }); // Add a second dummy spatial dimensions. ConvolutionDimensionNumbers dnums2d = dnums; dnums2d.add_input_spatial_dimensions(3); @@ -130,11 +128,10 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( auto convr3 = absl::make_unique>( convr4->planes(), convr4->depth(), convr4->height()); - convr4->Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; - }); + convr4->Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; + }); return convr3; } @@ -189,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + absl::Span 
window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -221,10 +217,10 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { +ReferenceUtil::ReduceWindow1DAdd(absl::Span operand, float init, + absl::Span window, + absl::Span stride, + Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; return ReduceWindow1DGeneric( @@ -236,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd( ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -275,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( - const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -286,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( - const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const 
tensorflow::gtl::ArraySlice& stride, Padding padding) { + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -334,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + absl::Span window, absl::Span stride, + Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -347,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -401,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( - const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -424,10 +415,12 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::SelectAndScatter4DGePlus( - const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& 
stride, bool same_padding) { +ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, + const Array4D& source, + float init, + absl::Span window, + absl::Span stride, + bool same_padding) { Padding padding = same_padding ? Padding::kSame : Padding::kValid; auto result = absl::make_unique>(operand.n1(), operand.n2(), operand.n3(), operand.n4()); @@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)); std::vector> paddings = MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, @@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim; dim.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0))); dim.set_stride(kernel_stride.first); dim.set_padding_low(paddings[0].first); dim.set_padding_high(paddings[0].second); @@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim2; dim2.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1))); dim2.set_stride(kernel_stride.second); dim2.set_padding_low(paddings[1].first); dim2.set_padding_high(paddings[1].second); @@ -564,35 +557,39 @@ 
ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = - ShapeInference::InferConvolveShape(lhs_literal->shape(), - rhs_literal->shape(), window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = ShapeInference::InferConvolveShape( + lhs_literal.shape(), rhs_literal.shape(), + /*feature_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; - std::unique_ptr result_literal = + Literal result_literal = evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); + CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); auto result = - absl::make_unique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal.shape().dimensions(0), + result_literal.shape().dimensions(1), + result_literal.shape().dimensions(2), + result_literal.shape().dimensions(3)); - result->Each([&](tensorflow::gtl::ArraySlice indices, float* value) { - *value = result_literal->Get(indices); + result->Each([&](absl::Span indices, float* value) { + *value = result_literal.Get(indices); }); return 
result; @@ -633,8 +630,7 @@ ReferenceUtil::ReduceToRowArray2D( } /*static*/ std::vector ReferenceUtil::Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); @@ -707,8 +703,7 @@ ReferenceUtil::ReduceToRowArray2D( } /* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function) { CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 88f853a3591c25289a8022909da8cdd4437883a6..8654fbb9b5e16c5ac13cb29aafeef8d142dbe39f 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -23,13 +23,13 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -144,8 +144,7 @@ class ReferenceUtil { // Returns the result of reducing the 4D array to a vector, reducing away // the dimensions specified in dims. static std::vector Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function); // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`. 
@@ -156,8 +155,7 @@ class ReferenceUtil { // Returns the result of reducing the 3D array to a 2D array, reducing away // the dimensions specified in dims. static std::unique_ptr> Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function); // Applies map_function to each element in the input (2D array) and returns @@ -179,47 +177,41 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + absl::Span operand, float init, + absl::Span window, absl::Span stride, + Padding padding); static std::unique_ptr> ReduceWindow2DAdd( - const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( - const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( - const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); // Windowed reductions with a generic reduce function. 
static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + absl::Span window, absl::Span stride, + Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -232,8 +224,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding); + absl::Span window, absl::Span stride, + bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. 
if concatenate_dimension is 0, the "n1"/height dimension is @@ -334,8 +326,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D( - const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + static std::vector ClampSlice1D(absl::Span input, int64 start, + int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { @@ -633,7 +625,7 @@ class ReferenceUtil { Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], output_bounds[3]); result.Each( - [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + [&](absl::Span indices, NativeT* value) { for (int i = 0; i < 4; ++i) { bool in_low_padding = indices[i] < pad_low[i]; bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 3ec0192148492c2516bf1c14fd4b960b08014388..a1b0f4045ff071454451f9fe3942ac974f4f47ac 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = 
ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, actual_literal, ErrorSpec(0.0001)); } @@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, actual_literal, ErrorSpec(0.0001)); } @@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); - LiteralTestUtil::ExpectR1Equal({0}, *result); + LiteralTestUtil::ExpectR1Equal({0}, result); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal, ErrorSpec(0.0001)); } @@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + 
LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray3D) { @@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal, ErrorSpec(0.0001)); } @@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, - *actual_literal, ErrorSpec(0.0001)); + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray4D) { @@ -194,7 +194,7 @@ 
TEST_F(ReferenceUtilTest, SliceArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray4D) { @@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { @@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + 
LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 44b22a5586dee3f7dd8ea0edbf9deb2090986ac8..3abb3855a42b8b5222115262448d359da3a80e87 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -34,15 +34,25 @@ cc_library( ], ) -tf_cc_binary( - name = "grpc_service_main_cpu", +cc_library( + name = "grpc_service_main_library", srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_binary( + name = "grpc_service_main_cpu", + deps = [ + ":grpc_service_main_library", + "//tensorflow/compiler/xla/service:cpu_plugin", ], ) @@ -62,6 +72,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 67886761813f0bb45a600661b017be91ffeade73..84fe5b17d10fba8c9f44314bec2b827e98ff6b33 100644 --- 
a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/subprocess.h" @@ -46,7 +46,7 @@ class GRPCClientTestBase : public ::testing::Test { int port = tensorflow::internal::PickUnusedPortOrDie(); subprocess_.SetProgram( service_main_path, - {service_main_path, tensorflow::strings::Printf("--port=%d", port)}); + {service_main_path, absl::StrFormat("--port=%d", port)}); subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_DUPPARENT); subprocess_.SetChannelAction(tensorflow::CHAN_STDERR, @@ -54,9 +54,8 @@ class GRPCClientTestBase : public ::testing::Test { CHECK(subprocess_.Start()); LOG(INFO) << "Launched subprocess"; - auto channel = - ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port), - ::grpc::InsecureChannelCredentials()); + auto channel = ::grpc::CreateChannel(absl::StrFormat("localhost:%d", port), + ::grpc::InsecureChannelCredentials()); channel->WaitForConnected(gpr_time_add( gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); LOG(INFO) << "Channel to server is connected on port " << port; @@ -96,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; - std::unique_ptr expected_literal = - 
LiteralUtil::CreateR1(expected); + Literal expected_literal = LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal, ErrorSpec(0.0001))); } diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index c68c857c304138ff4318e243f66547c6acce1005..522ab99fb1feff69610af887b58f197211cdb21f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -18,8 +18,9 @@ limitations under the License. #include "grpcpp/security/server_credentials.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,8 +30,15 @@ namespace { int RealMain(int argc, char** argv) { int32 port = 1685; + bool any_address = false; + string platform_str; std::vector flag_list = { - tensorflow::Flag("port", &port, "port to listen on"), + tensorflow::Flag("platform", &platform_str, + "The XLA platform this service should be bound to"), + tensorflow::Flag("port", &port, "The TCP port to listen on"), + tensorflow::Flag( + "any", &any_address, + "Whether to listen to any host address or simply localhost"), }; string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); + 
se::Platform* platform = nullptr; + if (!platform_str.empty()) { + platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie(); + } std::unique_ptr service = - xla::GRPCService::NewService().ConsumeValueOrDie(); + xla::GRPCService::NewService(platform).ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(tensorflow::strings::Printf("localhost:%d", port)); + string server_address( + absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port)); + builder.SetMaxReceiveMessageSize(INT_MAX); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); LOG(INFO) << "Server listening on " << server_address; server->Wait(); - return 0; } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 47d376c8ac8b3522757dd7b728394151b1c5ffa6..7d03eba800f6882efae448e3e41c488c513f4a84 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -86,6 +87,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -103,6 +105,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -121,6 +124,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -142,6 +146,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -157,6 +163,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -177,7 +184,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -194,6 +204,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -229,6 +240,7 @@ cc_library( hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_query", ":shape_inference", "//tensorflow/compiler/xla:literal", @@ -242,9 +254,11 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -280,11 +294,14 @@ cc_library( srcs = [ "dfs_hlo_visitor.cc", "hlo_computation.cc", + "hlo_input_output_alias_config.cc", "hlo_instruction.cc", "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", 
"hlo_sharding.cc", + "hlo_sharding_metadata.cc", ], hdrs = [ "dfs_hlo_visitor.h", @@ -292,11 +309,14 @@ cc_library( "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", + "hlo_input_output_alias_config.h", "hlo_instruction.h", "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", + "hlo_sharding_metadata.h", ], deps = [ ":hlo_casting_utils", @@ -321,9 +341,14 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -339,6 +364,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -349,8 +375,11 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/utility", ], ) @@ -376,6 +405,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", ], ) @@ -388,6 +419,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -433,6 +465,7 @@ tf_cc_test( 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -463,8 +496,11 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -482,6 +518,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -530,6 +567,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -552,6 +590,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -567,6 +606,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -594,6 +634,7 @@ cc_library( "//third_party/eigen3", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/types:span", ], ) @@ -637,6 +678,8 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -671,6 +714,8 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -744,8 +789,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -781,9 +829,11 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", + ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -795,6 +845,9 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) @@ -807,12 +860,14 @@ cc_library( ":executable", ":hlo", ":hlo_module_config", + ":hlo_module_group", ":logical_buffer", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/types:span", ], ) @@ -844,6 +899,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", 
"@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -862,6 +918,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -900,6 +957,7 @@ cc_library( "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -910,6 +968,8 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -945,7 +1005,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -962,6 +1025,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -979,8 +1043,8 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_memory_scheduler", ":hlo_proto", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -990,8 +1054,12 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1006,8 +1074,8 @@ tf_cc_test( ":cpu_plugin", 
":flatten_call_graph", ":hlo", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1017,8 +1085,10 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1039,7 +1109,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1050,14 +1122,15 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dataflow_analysis", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1075,6 +1148,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1092,12 +1167,47 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/memory", ], ) +cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_module_group_test", + srcs = ["hlo_module_group_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_module_group", + ":hlo_module_group_metadata", + ":hlo_parser", + ":hlo_proto", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_module_group_metadata", srcs = ["hlo_module_group_metadata.cc"], @@ -1105,12 +1215,14 @@ cc_library( deps = [ ":hlo", ":hlo_casting_utils", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -1131,19 +1243,62 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_schedule_test", + srcs = ["hlo_schedule_test.cc"], + deps = [ + ":heap_simulator", + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "hlo_input_output_alias_config_test", + srcs = ["hlo_input_output_alias_config_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) cc_library( - name = "hlo_scheduling", - srcs = ["hlo_scheduling.cc"], - hdrs = ["hlo_scheduling.h"], + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], deps = [ ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1153,25 +1308,29 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) tf_cc_test( - name = "hlo_scheduling_test", - srcs = ["hlo_scheduling_test.cc"], + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], deps = [ ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1186,16 +1345,27 @@ 
cc_library( ], ) +cc_library( + name = "fusion_queue", + hdrs = ["fusion_queue.h"], + deps = [ + ":hlo", + ], +) + cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":fusion_queue", ":hlo", ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", ], ) @@ -1221,6 +1391,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1254,6 +1426,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -1275,7 +1448,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -1327,6 +1502,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1339,6 +1515,7 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_creation_utils", ":hlo_pass", ":hlo_query", @@ -1356,6 +1533,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", 
+ "@com_google_absl//absl/types:span", ], ) @@ -1365,6 +1543,7 @@ tf_cc_test( deps = [ ":algebraic_simplifier", ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal", @@ -1525,6 +1704,8 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -1556,6 +1737,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1640,6 +1822,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1680,40 +1863,6 @@ tf_cc_test( ], ) -cc_library( - name = "inliner", - srcs = ["inliner.cc"], - hdrs = ["inliner.h"], - deps = [ - ":hlo", - ":hlo_pass", - ":hlo_query", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "inliner_test", - srcs = ["inliner_test.cc"], - deps = [ - ":cpu_plugin", - ":hlo", - ":hlo_matchers", - ":inliner", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "@com_google_absl//absl/memory", - ], -) - cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], @@ -1746,6 +1895,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/strings:str_format", ], ) @@ -1783,6 +1933,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/types:span", ], ) @@ -1881,6 +2032,9 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_memory_scheduler", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1889,7 +2043,9 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -1906,6 +2062,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1917,6 +2074,7 @@ cc_library( ":logical_buffer", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1934,6 +2092,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1951,8 +2110,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1971,9 +2132,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -2053,6 +2216,7 @@ cc_library( "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2074,7 +2238,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -2095,6 +2262,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2132,9 +2300,13 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -2153,6 +2325,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2185,8 +2358,12 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + 
"@com_google_absl//absl/types:span", ], ) @@ -2209,6 +2386,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2229,6 +2408,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2277,7 +2457,9 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2290,6 +2472,7 @@ tf_cc_test( ":hlo", ":hlo_parser", ":hlo_verifier", + ":layout_assignment", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2308,12 +2491,11 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", - ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2323,8 +2505,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2341,6 +2526,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2407,9 +2593,11 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -2421,6 +2609,7 @@ cc_library( ], deps = [ ":hlo", + ":hlo_module_group", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -2446,8 +2635,31 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "hlo_pass_pipeline_test", + srcs = ["hlo_pass_pipeline_test.cc"], + deps = [ + ":hlo", + ":hlo_parser", + ":hlo_pass_pipeline", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -2464,6 +2676,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2482,6 +2696,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2520,6 +2735,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2535,21 +2751,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - ], -) - -cc_library( - name = "hlo_sharding_metadata", - srcs = ["hlo_sharding_metadata.cc"], - hdrs = [ - "hlo_sharding_metadata.h", - ], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:shape_tree", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -2604,7 +2807,6 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2658,6 +2860,22 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "maybe_owning_device_memory", + srcs = [ + "maybe_owning_device_memory.cc", + ], + hdrs = [ + "maybe_owning_device_memory.h", + ], + deps = [ + ":device_memory_allocator", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", ], ) @@ -2667,6 +2885,7 @@ cc_library( hdrs = ["elemental_ir_emitter.h"], deps = [ ":hlo", + ":hlo_casting_utils", 
":hlo_module_config", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2675,6 +2894,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -2778,6 +2998,7 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -2803,6 +3024,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], alwayslink = 1, @@ -2927,6 +3149,7 @@ cc_library( ":buffer_assignment", ":hlo", ":hlo_proto", + ":hlo_verifier", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", ], @@ -2960,6 +3183,7 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3003,6 +3227,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -3023,7 +3248,7 @@ cc_library( hdrs = ["tuple_util.h"], deps = [ ":hlo", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -3081,6 +3306,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ 
-3110,6 +3337,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3143,13 +3372,13 @@ cc_library( cc_library( name = "source_map_util", - srcs = ["source_map_util.cc"], + srcs = [], hdrs = ["source_map_util.h"], deps = [ ":executable", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3165,6 +3394,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -3192,18 +3423,17 @@ cc_library( deps = [ ":hlo", ":hlo_lexer", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3212,6 +3442,8 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", @@ -3248,6 +3480,39 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "map_inliner", + srcs = ["map_inliner.cc"], + hdrs = ["map_inliner.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "map_inliner_test", + srcs = ["map_inliner_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":map_inliner", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index c236453fc77c4082be295156889e7be22f55152e..ca71f2cc129fc5d14e454c98a6e5ebf2e94cd7d2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -26,13 +26,16 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -44,7 +47,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -125,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleIota(HloInstruction* instruction) override; + Status HandleConvolution(HloInstruction* convolution) override; Status HandleDivide(HloInstruction* divide) override; @@ -201,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -292,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr FoldConvInputPad(HloInstruction* convolution); + StatusOr FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -308,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. 
bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. bool enable_conv_simplification_; // Cached computation for adding two scalar F32. @@ -447,8 +460,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { Status AlgebraicSimplifierVisitor::HandleConcatenate( HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + absl::Span operands(concatenate->operands()); if (operands.size() == 1) { // Unary concatenates are useless. ReplaceInstructionIfSameShape(concatenate, operands[0]); @@ -524,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + HloInstruction::CreateConstant(literal.Clone())); } } @@ -543,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = absl::make_unique( + Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -551,6 +563,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); } + + // If a literal is an increasing sequence from zero, replace it with an iota. 
+ if (ShapeUtil::Rank(constant->shape()) == 1 && + ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsR1Iota()) { + return ReplaceWithNewInstruction( + constant, HloInstruction::CreateIota(constant->shape(), 0)); + } return Status::OK(); } @@ -578,7 +598,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { namespace { template Status InvertConstant(const HloInstruction& constant, Literal* result) { - return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return result->Populate([&](absl::Span indices) { return T{1.0} / constant.literal().Get(indices); }); } @@ -665,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + HloInstruction::CreateConstant((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -725,12 +745,25 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( } const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { return hlo; } - return computation_->AddInstruction( - HloInstruction::CreateReshape(dot->shape(), hlo)); + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + }; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + hlo = as_type(hlo, dot->shape().element_type()); + if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + hlo = computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + } + return hlo; + }; + + auto add_reduce_in_f32 = 
[&](HloInstruction* hlo, const int64 dim) { + return AddReduce(as_type(hlo, F32), dim); }; auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, @@ -750,7 +783,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( if (ShapeUtil::Rank(rhs->shape()) == 1 && ShapeUtil::Rank(lhs->shape()) == 1) { TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), Flatten(rhs)), 0)))); return true; } @@ -784,17 +817,17 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, - reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(Flatten(lhs), rhs), 0)))); return true; } TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary( - AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), - rhs_collapsing_dim), - rhs), - rhs_collapsing_dim)))); + dot, reshape_if_necessary(add_reduce_in_f32( + multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); return true; } @@ -806,7 +839,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(AddReduce( + dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), lhs_collapsing_dim)), lhs_collapsing_dim)))); @@ -939,9 +972,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( new_dot_rhs = rhs_slice; } - auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot.shape(), new_dot_lhs, new_dot_rhs, 
new_dot_dnums)); - new_dot->set_precision_config(dot.precision_config()); + auto* new_dot = computation_->AddInstruction( + HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, + new_dot_dnums, dot.precision_config())); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1041,10 +1074,11 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); - auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); - auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( - memoized_shape, left_operand, right_operand, dnums)); - memoized_inst->set_precision_config(dot->precision_config()); + auto memoized_shape = + ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); + auto* memoized_inst = computation_->AddInstruction( + HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, + dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1089,10 +1123,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. 
+ if ((dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) || + ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || + ShapeUtil::Rank(dot->shape()) > 2) { return Status::OK(); } @@ -1140,9 +1176,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), - rhs->mutable_operand(0), lhs->mutable_operand(0), - dot_dimension_numbers)); - new_dot->set_precision_config(dot->precision_config()); + rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, + dot->precision_config())); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -1238,9 +1273,8 @@ namespace { // return value = {1, 3} // // Precondition: input_dim_indices is sorted. -std::pair> ReshapeLeavesDimensionsUnmodified( - const HloInstruction* hlo, - tensorflow::gtl::ArraySlice input_dim_indices) { +absl::optional> ReshapeLeavesDimensionsUnmodified( + const HloInstruction* hlo, absl::Span input_dim_indices) { CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); @@ -1258,11 +1292,11 @@ std::pair> ReshapeLeavesDimensionsUnmodified( } if (i >= unmodified_dims.size() || unmodified_dims[i].first != input_dim_index) { - return std::make_pair(false, std::vector()); + return absl::nullopt; } output_dim_indices.push_back(unmodified_dims[i].second); } - return std::make_pair(true, output_dim_indices); + return output_dim_indices; } // Returns true if the output of "instruction" is a permutation of the @@ -1391,6 +1425,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } + // broadcast(iota) -> iota. 
+ if (operand->opcode() == HloOpcode::kIota) { + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateIota( + broadcast->shape(), + dims[Cast(operand)->iota_dimension()])); + } + // Merge two consecutive broadcasts into a single one. if (operand->opcode() == HloOpcode::kBroadcast) { std::vector new_dimensions; @@ -1445,6 +1488,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { + // iota -> zero if the iota dimension never produces an element other than + // zero. + auto* iota = Cast(instruction); + if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(iota->shape().element_type()).Clone())); + return ReplaceWithNewInstruction( + iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( @@ -1541,7 +1597,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - LiteralUtil::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1576,7 +1632,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 
to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -1719,12 +1775,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); - if (opt_dims.first) { + if (opt_dims.has_value()) { return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), - opt_dims.second)); + *opt_dims)); + } + } + + // reshape(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + auto* iota = Cast(operand); + auto opt_dims = + ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); + if (opt_dims.has_value()) { + CHECK_EQ(opt_dims->size(), 1); + return ReplaceWithNewInstruction( + reshape, + HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); } } @@ -1821,7 +1890,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || ShapeUtil::IsZeroElementArray(reduce->shape())) { @@ -1988,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return Status::OK(); } + // Bail on dilation. 
+ if (window_util::HasDilation(window)) { + VLOG(10) << "Not folding pad into reduce-window as there is dilation."; + return Status::OK(); + } + VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() << (convert != nullptr @@ -2013,12 +2088,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (pad_literal == reduce_init_literal) { return true; } - auto converted_pad_literal = pad_literal.ConvertToShape( - reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + auto converted_pad_literal = + pad_literal.ConvertToShape(reduce_init_value->shape()); if (!converted_pad_literal.ok()) { return false; } - return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + return converted_pad_literal.ValueOrDie() == reduce_init_literal; }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. @@ -2134,7 +2209,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { } // If it is key/value sort, the output of sort is a tuple. 
return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + sort, HloInstruction::CreateTuple(sort->operands())); } return Status::OK(); } @@ -2168,40 +2243,157 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()) - .CloneToUnique())), - {})); + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. 
+ Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { + return false; + } + + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); + } + } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} + +StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; + } + + const auto& padding = rhs->padding_config(); + + // Can't pad or dilate feature dims. 
+ for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); + + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; + } + + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } + + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. 
+ if (w.window_dilation() > 1) { + return false; + } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} + +StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. - // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } - const ConvolutionDimensionNumbers& dnums = - convolution->convolution_dimension_numbers(); const Shape& input_shape = lhs->shape(); const Shape& filter_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -2212,7 +2404,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. 
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2223,7 +2415,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2249,7 +2441,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2291,7 +2483,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2300,10 +2492,47 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); - dot->set_precision_config(convolution->precision_config()); + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, + convolution->precision_config())); + + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. 
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index b864c372fa5877ca329d2efbbf7d747c763ae2c0..9f8d0ee88bdebcf17310cd0407b1b99e4b0a7b5f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // A pass which performs algebraic simplifications. 
-class AlgebraicSimplifier : public HloPassInterface { +class AlgebraicSimplifier : public HloModulePass { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index bb63ea26d453e52a6f39551a83a36eabe9709438..42d1f337dc22b91dcef4eb8ed4c0c57c6febeb70 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -23,8 +23,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -52,12 +54,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase { - public: - AlgebraicSimplifierTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; +class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -296,6 +293,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { EXPECT_THAT(root, op::Constant()); } +TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + 
LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -519,7 +531,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1({1.f, 2.f, 3.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); @@ -1032,7 +1044,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { dim->set_window_reversal(false); // Create add computation. 
builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1826,6 +1839,126 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { op::Reshape(op::Broadcast(param))); } +TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); + Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); + auto result_shape = iota->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + auto root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(0.0f, 
root->operand(0)->literal().GetFirstElement()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_EQ(Cast(computation->root_instruction()) + ->iota_dimension(), + 3); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); + + HloComputation* computation = 
module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + const int64 iota_dim = + Cast(computation->root_instruction()) + ->iota_dimension(); + EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -2000,16 +2133,283 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto values = builder.AddInstruction( - HloInstruction::CreateParameter(1, values_shape, "values")); + auto values0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values0")); + auto values1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, values_shape, "values1")); builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, 
values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, + keys, {values0, values1})); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(keys, values0, values1)); +} + +// Used for TEST_Ps that test merging (or not) of a kPad instruction into a +// convolution's Window. +struct ConvPaddingTestcase { + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window) + : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window, + /*pad_value=*/0) {} + + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window, float pad_value) + : padding(padding), + orig_conv_window(orig_conv_window), + expected_conv_window(expected_conv_window), + pad_value(pad_value) {} + + string ToString() const { + return absl::StrFormat( + "padding=%s, orig_conv_window=%s, expected_conv_window=%s, " + "pad_value=%f", + padding, orig_conv_window, expected_conv_window, pad_value); + } + + string padding; + string orig_conv_window; + string expected_conv_window; + float pad_value; +}; + +// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a +// computation that does +// +// conv(pad(param0, padding=padding), param1), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. 
+class ConvInputPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P( + ConvInputPaddingTestCases, ConvInputPaddingTest, + ::testing::ValuesIn(std::vector{ + // Merge this edge padding into the conv. + {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"}, + // Merge this edge padding with the conv's edge padding. + {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"}, + // Merge this interior-padded kPad with the unpadded conv. The 3x6 + // interior padding gets transformed to 4x7 conv lhs dilation. + {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"}, + // kPad has dilation on one dim, conv has it on the other; merge them. + {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"}, + // kPad has dilation and edge padding on one dim, conv has them on the + // other; merge them. + {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10", + "pad=0_1x3_0 lhs_dilate=2x10"}, + + // Don't transform if the pad value is nonzero. + {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1}, + + // We refuse to transform the following because on some dimension, one + // of the kPad and conv has dilation and the other has some sort of + // padding. + {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + + // We can't merge feature or batch padding into the conv. + {"1_0x0_0x0_0x0_0", "", ""}, + {"0_0x1_0x0_0x0_0", "", ""}, + })); + +TEST_P(ConvInputPaddingTest, DoTest) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. 
+ SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01 + "input")); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(input->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + input, pad_value, padding_config)); + + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, + ShapeUtil::MakeShape( + F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = + ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window)) + .ValueOrDie(); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + lhs_pad, filter, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrCat("size=3x3 ", testcase.expected_conv_window)); + } +} + +// ConvFilterPaddingTest (and 
its one associated TEST_P) checks that a +// computation that does +// +// conv(param0, pad(param1, padding=padding)), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvFilterPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P( + ConvFilterPaddingTestCases, ConvFilterPaddingTest, + ::testing::ValuesIn(std::vector{ + // Can only merge interior padding on the filter's spatial dimensions; + // all + // other paddings (edge padding and interior padding on the channel + // dims) + // should be rejected out of hand. + {"1_0_0x0_0_0x0_0x0_0", "", ""}, + {"0_1_0x0_0_0x0_0x0_0", "", ""}, + {"0_0_1x0_0_0x0_0x0_0", "", ""}, + {"0_0_0x1_0_0x0_0x0_0", "", ""}, + {"0_0_0x0_1_0x0_0x0_0", "", ""}, + {"0_0_0x0_0_1x0_0x0_0", "", ""}, + {"0_0_0x0_0_0x1_0x0_0", "", ""}, + {"0_0_0x0_0_0x0_1x0_0", "", ""}, + {"0_0_0x0_0_0x0_0x1_0", "", ""}, + {"0_0_0x0_0_0x0_0x0_1", "", ""}, + + // Interior padding on channel dims can be merged into the conv, so long + // as the conv and pad don't have interior padding on the same dim. + {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"}, + {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"}, + {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"}, + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"}, + + // Can't merge if for a given dim there's interior padding on both the + // pad and conv. + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""}, + + // Don't transform if the pad value is nonzero. 
+ {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1}, + })); + +TEST_P(ConvFilterPaddingTest, DoIt) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. + SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01 + "input")); + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(filter->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + filter, pad_value, padding_config)); + + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeShape( + F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = ParseWindow(absl::StrFormat("size=%dx%d %s", + rhs_pad->shape().dimensions(2), + rhs_pad->shape().dimensions(3), + testcase.orig_conv_window)) + .ValueOrDie(); + + // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place + // after the transformation. 
+ PrecisionConfig precision_config; + precision_config.add_operand_precision(PrecisionConfig::HIGH); + precision_config.add_operand_precision(PrecisionConfig::HIGHEST); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + precision_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrFormat("size=%dx%d %s", + conv->operand(1)->shape().dimensions(2), + conv->operand(1)->shape().dimensions(3), + testcase.expected_conv_window)); + EXPECT_THAT(Cast(conv) + ->precision_config() + .operand_precision(), + ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); + } } TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { @@ -2115,7 +2515,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto out_dims = in_dims; out_dims[in_channel_idx] = options.f_output_channels; - auto make_shape = [](tensorflow::gtl::ArraySlice dims, + auto make_shape = [](absl::Span dims, bool minor_to_major_layout) { if (minor_to_major_layout) { return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3}); @@ -2132,8 +2532,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { HloInstruction* filter = b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); - b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, - 
window, dnums)); + b.AddInstruction(HloInstruction::CreateConvolve( + out_shape, input, filter, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. auto module = HloTestBase::CreateNewModule(); @@ -2511,7 +2912,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, + DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2534,9 +2936,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + Literal elements[] = {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector)}; + Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); @@ -2653,6 +3055,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } +// Test that a broadcast of an iota can be merged to one iota. 
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1)); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast(root)->iota_dimension(), 2); +} + +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + HloComputation::Builder builder(TestName()); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1)); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast(root)->iota_dimension(), 2); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -2686,8 +3129,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { // a and b are parallel bounds we can either turn into a B F S0 S1 
or // `B S0 S1 F` kind of pattern. - auto decorate_spatials = [¶m](tensorflow::gtl::ArraySlice spatials, - int64 a, int64 b) { + auto decorate_spatials = [¶m](absl::Span spatials, int64 a, + int64 b) { std::vector result; if (param.prepend_a) { result.push_back(a); @@ -2794,17 +3237,18 @@ INSTANTIATE_TEST_CASE_P( class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< - ::testing::tuple> {}; + ::testing::tuple> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { int m, k, n; bool transpose_lhs, transpose_rhs; - std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + PrimitiveType element_type; + std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam(); - Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); - Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); - Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k}); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -2822,8 +3266,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -2846,7 +3290,7 @@ INSTANTIATE_TEST_CASE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), - ::testing::Bool())); + ::testing::Bool(), ::testing::Values(F32, BF16))); struct DotOfConcatTestSpec { int64 m; @@ -2856,12 +3300,7 @@ struct DotOfConcatTestSpec { class DotOfConcatSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface { - public: - DotOfConcatSimplificationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; + public ::testing::WithParamInterface {}; // Test that we transform // dot(const, concat(A, B, C)) @@ -2903,8 +3342,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2967,8 +3406,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3034,12 +3473,7 @@ struct DotOfGatherTestSpec { class DotOfGatherSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface { - 
public: - DotOfGatherSimplificationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; + public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. @@ -3090,8 +3524,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 dot_row_size = 1; int64 dot_col_size = spec.n; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3160,8 +3594,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 dot_row_size = spec.m; int64 dot_col_size = 1; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 5115a14df02a780cd51bf8c96825d2f390cf6ec8..ef5e211646e7b0b66b8e6c09948be58063422943 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -69,8 +69,7 @@ StatusOr AllocationTracker::RegisterInternal( return InvalidArgument( "AllocationTracker for platform %s cannot register buffer from " "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer.platform()->Name().c_str()); + backend_->platform()->Name(), shaped_buffer.platform()->Name()); } 
} @@ -125,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { // "handle does not exist". auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } for (auto& shaped_buffer : it->second) { @@ -144,7 +143,7 @@ StatusOr> AllocationTracker::DeconstructTuple( // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { - return InvalidArgument("global data handle %lld is not a tuple", + return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be @@ -177,13 +176,13 @@ StatusOr> AllocationTracker::DeconstructTuple( } StatusOr> AllocationTracker::Resolve( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } StatusOr AllocationTracker::ResolveForReplica( - const GlobalDataHandle& data, int replica_id) { + const GlobalDataHandle& data, int replica_id) const { tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, ResolveInternal(data)); @@ -197,18 +196,18 @@ StatusOr AllocationTracker::ResolveForReplica( } StatusOr> AllocationTracker::ResolveInternal( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } std::vector replicated_buffers; for (const auto& shaped_buffer : it->second) { if 
(shaped_buffer == nullptr) { - return InvalidArgument( - "global data handle %lld was previously deallocated", data.handle()); + return InvalidArgument("global data handle %d was previously deallocated", + data.handle()); } replicated_buffers.push_back(shaped_buffer.get()); } diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index a7d8927cf7e90d764ff8046df16c71922b11478e..98d1a302a9f66f4a00e05d62837a79133e222687 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -64,13 +65,13 @@ class AllocationTracker { // replica, or provide an error status to say whether any of those buffers // were not found (or found, but found deallocated). StatusOr> Resolve( - const GlobalDataHandle& data); + const GlobalDataHandle& data) const; // Resolves a handle from an XLA client and replica id to a shaped buffer, or // provide an error status to say whether it was not found (or found, but // found deallocated). StatusOr ResolveForReplica(const GlobalDataHandle& data, - int replica_id); + int replica_id) const; private: // Data structure encapsulating single memory allocation on the device. @@ -86,7 +87,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. StatusOr> ResolveInternal( - const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. 
If @@ -110,9 +111,9 @@ class AllocationTracker { // A map from device memory opaque value to allocation. One such map is // maintained per device ordinal. - using AllocationMap = tensorflow::gtl::FlatMap; + using AllocationMap = absl::flat_hash_map; - tensorflow::mutex mutex_; + mutable tensorflow::mutex mutex_; // Backend to use with this tracker. The backend supplies the memory allocator // to use when deallocating memory. @@ -123,10 +124,7 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - // - // This is not a TF FlatMap because (currently) FlatMap (and therefore - // AllocationMap) is not movable. - std::unordered_map opaque_to_allocation_map_ + absl::flat_hash_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the @@ -146,7 +144,7 @@ class AllocationTracker { // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then // free'd when both the view *and* the original tuple are Unregistered. This // refcounting is managed in opaque_to_allocation_map_. 
- tensorflow::gtl::FlatMap>> + absl::flat_hash_map>> handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 841d0fa85bb9c548cd737e21bb988886f43378bd..5c180cbdd492031e133b81149f0f4698619b7788 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -112,11 +112,11 @@ StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { return stream_pools_.at(executor).BorrowStream(executor); } -Backend::Backend( - se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, ComputationPlacer* computation_placer, - int intra_op_parallelism_threads) +Backend::Backend(se::Platform* platform, Compiler* compiler, + absl::Span stream_executors, + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), @@ -177,7 +177,7 @@ StatusOr Backend::stream_executor( } } return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); + device_name(device_ordinal)); } StatusOr Backend::devices_equivalent(int device_ordinal_a, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 4a6a78daf07256684402f448725b219d5983ed9e..a2dafbe803f8bd5f23e4e9f3f6d3e6f744c9fab9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -149,7 +149,7 @@ class Backend { private: struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, + absl::Span stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index a16b85a0a5e3f72f54e9733bb974b01377e0c358..eda026ac5685dc469a6230094eb28b3618e36400 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); - new_dot->set_precision_config(batch_dot->precision_config()); + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, + batch_dot->precision_config())); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index 
79d37f08d3553321ebbabc44c8f2488b194954d5..5b625bf3b98b060531532f07de343f7ca4f09ac9 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -25,7 +25,7 @@ namespace xla { // Normally these would live in the algebraic simplifier, but we want to run // this to fixpoint (this pass reaches fixed point in one execution) before we // run the DotDecomposer. -class BatchDotSimplification : public HloPassInterface { +class BatchDotSimplification : public HloModulePass { public: StatusOr Run(HloModule* module) override; absl::string_view name() const override; diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index b342acb0259498c2255f55da1cb7a3da700bdca4..38f1a5d3a645f98220ec445bb9bbdf2b9b842109 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -24,12 +24,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase { - public: - BatchDotSimplificationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; +class BatchDotSimplificationTest : public HloVerifiedTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 01931b2d02c2771b85474ca0cb6a1a92b3e9ffe7..f70f6ddfec69c0113a1afe2073a2392098f49456 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -34,8 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -205,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -331,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( const Shape feature_shape = scale->shape(); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, computation_->AddInstruction( @@ -464,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = 
LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add( @@ -560,7 +559,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto elements_per_feature_literal = LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); + elements_per_feature_literal.Convert(ptype)); auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 76e32174f3ee7d319df6f1f465e19d265d5330f2..147f3ae7b6d4ed0d4dadfb136e1e0f0bf3ae90c6 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -26,7 +26,7 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormExpander : public HloPassInterface { +class BatchNormExpander : public HloModulePass { public: // When use_fusion is set, a multi-output fusion node is created. 
BatchNormExpander(bool rewrite_training_op = false, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aba0d9bb5b977d89656580df46838eefb8cd6662..f7ac8f5482908af104554a1cf812370b9098cda7 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +using BatchNormExpanderTest = HloVerifiedTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + ParseAndVerifyModule(module_str); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); - for (auto* instruction : module->entry_computation()->instructions()) { + for (auto* instruction : module().entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 1b8b2d204503576c3fcb02f6d5b37f2db45e1768..d63287539dfde5bb4890ab8303ef2205133d8125 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index 5dcd31b83d24f836d31f44181f39cb8371ca1033..cb3d12f0bfd0e502136ce39660e091dc1c3879be 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -31,7 +31,7 @@ namespace xla { // optimization pipeline followed by a DCE pass. If other passes are needed // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the // changed made by this pass. -class BFloat16ConversionFolding : public HloPassInterface { +class BFloat16ConversionFolding : public HloModulePass { public: explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 6363a21c3bafe8353a6ebfde405bb7a3736c2074..5f93740887aa7e61458990992fe0573883ff056d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { protected: + BFloat16ConversionFoldingTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); @@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { auto 
module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 32573ed3555204c059d092ef65b18b38b19f9ea5..1251f0258f5d43a490ad654f519fee9076590453 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -69,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // Inserts conversion HLOs to replace the called computations' BF16 // operands/outputs to F32. 
Status ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps); + HloInstruction* hlo, absl::Span bf16_called_comps); HloComputation* computation_; const BFloat16Support* bfloat16_support_; @@ -114,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( } Status BFloat16NormalizationVisitor::ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps) { + HloInstruction* hlo, absl::Span bf16_called_comps) { std::map cloned_computations; for (auto& comp : bf16_called_comps) { auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); @@ -233,6 +231,10 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( for (auto* user : materialized_users) { TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } + bool is_root = computation_->root_instruction() == hlo; + if (is_root) { + computation_->set_root_instruction(tuple); + } *tuple->mutable_shape() = original_shape; return Status::OK(); } @@ -359,6 +361,7 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + // TODO(b/112040122): Correctly normalize variadic reduce. if ((hlo->opcode() == HloOpcode::kSort || hlo->opcode() == HloOpcode::kCrossReplicaSum) && ShapeUtil::IsTuple(hlo->shape())) { diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 30b6346312790f0a199f96f1956ba9ce3e617f72..f48e925823cf02bf4351b9bc7741123f5b1cd06f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -25,7 +25,7 @@ namespace xla { // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not // support BF16 input/output or mixed precision, according to the passed-in // backend-specific BF16 support rules. 
-class BFloat16Normalization : public HloPassInterface { +class BFloat16Normalization : public HloModulePass { public: explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} @@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface { // use mixed precision; it removes mixed precision even if the backend supports // it. This pass is used to make the HLO module valid for other HLO passes which // do not support mixed precision. -class BFloat16MixedPrecisionRemoval : public HloPassInterface { +class BFloat16MixedPrecisionRemoval : public HloModulePass { public: BFloat16MixedPrecisionRemoval() {} diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index b08705d4c2b644fe1a7ba9994876fd6397f8a5df..cb075a5e38a5ea9db2ceb432b2b59f8db5e2e640 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloTestBase { +class BFloat16NormalizationTest : public HloVerifiedTestBase { protected: + BFloat16NormalizationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16Normalization normalization(&bfloat16_support_); @@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module.get())); + EXPECT_FALSE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, 
ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -280,13 +284,13 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction::CreateParameter(1, s32_shape, "value")); HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value})); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -294,6 +298,30 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "value")); + + 
HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value})); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module)); + + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); + EXPECT_NE(computation->root_instruction(), sort); + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { @@ -308,13 +336,16 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 2fb401c4289728f3f59538464c5b8ad49957985b..002be9c97098ef1f73446c458dae24bbc826a626 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( }; auto root = fusion->fused_instructions_computation()->root_instruction(); - tensorflow::gtl::FlatSet changed_root_buffers; + absl::flat_hash_set changed_root_buffers; auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { @@ -407,7 +408,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters( HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { // Adjust parameters. CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { @@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet* visited_computations) { + absl::flat_hash_set* visited_computations) { bool parameter_changed = false; auto insts = computation->MakeInstructionPostOrder(); // Do the adjustment on each instruction in the computation in reverse @@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( // another input parameter. A fixed point will be reached because the // parameters can only be changed from BF16 to F32, not the other way // around. 
- tensorflow::gtl::FlatSet visited_in_while; + absl::flat_hash_set visited_in_while; while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), &visited_in_while) || ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), @@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { const auto& computations_topological_order = module->MakeComputationPostOrder(); - tensorflow::gtl::FlatSet resolved; + absl::flat_hash_set resolved; for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { if (ContainsKey(resolved, *comp_it)) { @@ -675,10 +676,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN( - auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape(), - /*round_f32_to_bf16=*/true)); + TF_ASSIGN_OR_RETURN(auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1ee64971ab53e1775294afde1c779369a838008a..5fcaa15c8356107af02e9099874a293d8350c51a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -58,7 +60,7 @@ namespace xla { // BFloat16ConversionFolding. 
If other passes are needed after this pass, run // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this // pass. -class BFloat16Propagation : public HloPassInterface { +class BFloat16Propagation : public HloModulePass { public: explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); @@ -81,7 +83,7 @@ class BFloat16Propagation : public HloPassInterface { // The set of instructions to consider using bfloat16, computed in the forward // pass. - tensorflow::gtl::FlatSet consider_using_bfloat16_; + absl::flat_hash_set consider_using_bfloat16_; // *************************** // Functions called and state produced by the backward pass (from root to @@ -110,12 +112,12 @@ class BFloat16Propagation : public HloPassInterface { // The set of HloInstructions that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet + absl::flat_hash_set instructions_visited_in_backward_pass_; // The set of HloComputations that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet + absl::flat_hash_set computations_visited_in_backward_pass_; // *************************** @@ -131,7 +133,7 @@ class BFloat16Propagation : public HloPassInterface { // point is reached. bool ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet* visited_computations); + absl::flat_hash_set* visited_computations); // Makes the parameters of called computations match how they are called by // the given HLO. @@ -182,11 +184,11 @@ class BFloat16Propagation : public HloPassInterface { PrimitiveType target_type); // The set of F32 HLO values that must be kept in F32. - tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; + absl::flat_hash_set values_that_must_be_kept_as_f32_; // Mapping from each HloComputation to the number of callers to it in the // module. Populated at the beginning of this pass. 
- tensorflow::gtl::FlatMap caller_counts_; + absl::flat_hash_map caller_counts_; // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which // are subject to further adjustment, then finally applied to the HLOs. This @@ -195,8 +197,7 @@ class BFloat16Propagation : public HloPassInterface { // // For each HloInstruction, changes_to_bf16_ stores the affected buffers in // the output as a map from in-place pointers to subshapes to shape indices. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 69b654d30e42b1ed69304206f09120e86831d468..e032b5c624c0151fd63c870e0f21ec97656d625f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloTestBase { +class BFloat16PropagationTest : public HloVerifiedTestBase { protected: + BFloat16PropagationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
bool PropagatePrecision(HloModule* module) { @@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase { inst->users()[0]->opcode() == HloOpcode::kConvert && inst->users()[0]->shape().element_type() == BF16; } + + std::unique_ptr CreateDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + DefaultPrecisionConfig(2)); + } }; // Tests that BF16 can propagate through select over non-tuple buffers, but not @@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); - HloInstruction* root = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - 
EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -150,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } @@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple0->shape(), tuple1, 0)), 0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); HloInstruction* output_tuple = 
builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); @@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { // Tests that a non-fusion computation's root should not be changed. 
TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* a = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); @@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add)); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); @@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); HloInstruction* b_f1 = builder_f1.AddInstruction( HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); - HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + HloInstruction* 
dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1)); auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); HloInstruction* add_f = builder_f.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); - HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + HloInstruction* dot_f = + builder_f.AddInstruction(CreateDot(shape, add_f, add_f)); auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { 
HloInstruction::CreateGetTupleElement(shape, fusion, 0)); HloInstruction* gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, fusion, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, 
{1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); auto body_param = builder_body.AddInstruction( HloInstruction::CreateParameter(0, shape, "body_param")); - auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, body_param, body_param)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_param, body_param)); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest, HloInstruction::CreateParameter(0, shape, "cond_param")); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), - 
builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1}, + {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest, auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // This add should prevent RHS from using BF16 auto cond_add_rhs = builder_cond.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - 
ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot1 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); - auto body_dot2 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_dot1 = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); + auto body_dot2 = + builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs)); auto body_transpose = builder_body.AddInstruction( HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( @@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); auto rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, 
rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond0_add_rhs = builder_cond0.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); - auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + auto cond0_dot = + builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); builder_cond0.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. 
@@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond1_add_lhs = builder_cond1.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); - auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + auto cond1_dot = + builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); builder_cond1.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. 
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); builder_body.AddInstruction( HloInstruction::CreateTuple({body_dot, body_rhs})); auto body = module->AddEmbeddedComputation(builder_body.Build()); @@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); - auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 1)))); - auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 1)))); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto lhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction::CreateGetTupleElement(shape, domain, 0)); HloInstruction* b_gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, domain, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); HloInstruction* b_trans = builder.AddInstruction( HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* dot = + builder.AddInstruction(CreateDot(shape, a_trans, b_trans)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, 
dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 23645346e6f491beb5171cc839c013ce5f83d789..5b48f10505e78c035608d4c575501e4623218987 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -78,8 +78,10 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch (hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: + case HloOpcode::kCollectivePermute: case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index c8c36ae60ed0e53234523fa0f7a904d9dbbe06d2..d5d6a044a81303425495202d8a98c6735b0b8b89 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,14 +22,16 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -37,17 +39,15 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::StrAppend; -using ::tensorflow::gtl::FlatMap; -using ::tensorflow::gtl::FlatSet; -using ::tensorflow::strings::Appendf; +using absl::StrAppendFormat; using ::tensorflow::strings::HumanReadableNumBytes; -using ::tensorflow::strings::Printf; template string ColocatedBufferSetsToString(const T& container, const char* title) { @@ -59,12 +59,65 @@ string ColocatedBufferSetsToString(const T& container, const char* title) { return result; } -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). 
+// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. 
+ std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace + Status GatherComputationsByAllocationType( const HloModule* module, std::vector* thread_local_computations, @@ -77,8 +130,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; + flat_hash_set thread_local_set; + flat_hash_set global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -105,7 +158,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s has conflicting allocation requirements (global " "and thread-local)", - computation->name().c_str()); + computation->name()); } if (is_thread_local) { @@ -128,7 +181,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s cannot contain call/while op because it " "requires thread-local buffer allocations", - computation->name().c_str()); + computation->name()); } worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. @@ -145,9 +198,8 @@ Status GatherComputationsByAllocationType( true)); // Thread local. 
break; default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InternalError("Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode())); } } } @@ -167,65 +219,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. 
- std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. - int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); @@ -246,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { - VLOG(4) << "Trying to add " << buffer << " to " << this; + VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -296,7 +289,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); + StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); if (color().value() != 0) { StrAppend(&output, ", color ", color().value()); } @@ -328,11 +321,10 @@ string 
BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - StrAppend(&output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, absl::StrFormat( + " %s [%d,%d]: %s\n", buffer->ToString(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()))); } return output; } @@ -425,7 +417,7 @@ StatusOr BufferAssignment::GetUniqueSlice( return FailedPrecondition( "BufferAllocation::Slice for instruction %s at index %s cannot " "be determined at compile-time.", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } } else { VLOG(3) << "No allocation"; @@ -434,7 +426,7 @@ StatusOr BufferAssignment::GetUniqueSlice( if (result.allocation() == nullptr) { return FailedPrecondition( "BufferAllocation::Slice not assigned for instruction %s at index %s", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } return result; } @@ -454,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { using SliceSet = - FlatSet; + flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -529,7 +521,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // BufferAllocation. 
void BufferAssignment::CombineTempAllocations() { VLOG(1) << "CombineTempAllocations()"; - FlatMap + flat_hash_map combined_allocation_map; // Move all temp allocations into a single run at the end of the allocations @@ -592,7 +585,8 @@ void BufferAssignment::CombineTempAllocations() { } // Update allocation indices to their new positions. - allocation_index_for_buffer_.clear_no_resize(); + allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(), + allocation_index_for_buffer_.end()); for (size_t index = 0; index < allocations_.size(); ++index) { BufferAllocation* allocation = &allocations_[index]; allocation->set_index(index); @@ -626,18 +620,24 @@ Status BufferAssignment::ComputeSummaryStats() { } // Only compute total fragmentation if all computations have schedules. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_); + bool schedule_complete = true; for (const auto& computation : module_->computations()) { - const std::vector* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence != nullptr) { - module_sequence.emplace(computation, *sequence); + if (!computation->IsFusionComputation()) { + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence == nullptr) { + schedule_complete = false; + } else { + schedule.set_sequence(computation, *sequence); + } } } - if (module_sequence.size() == module_->computation_count()) { + if (schedule_complete) { + TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -646,30 +646,29 @@ Status BufferAssignment::ComputeSummaryStats() { string BufferAssignment::Stats::ToString() const { string s; - Appendf(&s, "BufferAssignment stats:\n"); - Appendf(&s, " 
parameter allocation: %10s\n", - HumanReadableNumBytes(parameter_allocation_bytes).c_str()); - Appendf(&s, " constant allocation: %10s\n", - HumanReadableNumBytes(constant_allocation_bytes).c_str()); - Appendf(&s, " maybe_live_out allocation: %10s\n", - HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); - Appendf(&s, " preallocated temp allocation: %10s\n", - HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str()); + StrAppendFormat(&s, "BufferAssignment stats:\n"); + StrAppendFormat(&s, " parameter allocation: %10s\n", + HumanReadableNumBytes(parameter_allocation_bytes)); + StrAppendFormat(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes)); + StrAppendFormat(&s, " maybe_live_out allocation: %10s\n", + HumanReadableNumBytes(maybe_live_out_allocation_bytes)); + StrAppendFormat(&s, " preallocated temp allocation: %10s\n", + HumanReadableNumBytes(preallocated_temp_allocation_bytes)); if (preallocated_temp_fragmentation_bytes >= 0) { const double percent = 100. * preallocated_temp_fragmentation_bytes / preallocated_temp_allocation_bytes; - Appendf( + StrAppendFormat( &s, " preallocated temp fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(), - percent); + HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent); } - Appendf(&s, " total allocation: %10s\n", - HumanReadableNumBytes(total_allocation_bytes).c_str()); + StrAppendFormat(&s, " total allocation: %10s\n", + HumanReadableNumBytes(total_allocation_bytes)); if (total_fragmentation_bytes >= 0) { const double percent = 100. 
* total_fragmentation_bytes / total_allocation_bytes; - Appendf(&s, " total fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent); + StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(total_fragmentation_bytes), percent); } return s; } @@ -785,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } } - if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - const HloComputation* entry_computation = - assignment->module_->entry_computation(); - for (auto param : entry_computation->parameter_instructions()) { - for (auto& param_buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - param)) { - if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { - VLOG(4) << "Can't assign: Parameter interference with result"; - return false; - } - } - } - } - // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). @@ -817,9 +801,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const FlatSet& colocated_buffers, - const FlatSet& colocated_allocations, - FlatMap>* + const flat_hash_set& colocated_buffers, + const flat_hash_set& colocated_allocations, + flat_hash_map>* buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of @@ -838,7 +822,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Generate a post order sort of instructions for sorting of the // LogicalBuffers. 
- FlatMap post_order_position; + flat_hash_map post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); @@ -855,8 +839,8 @@ Status BufferAssigner::AssignBuffersForComputation( // buffers_to_assign_sequentially map, even if we end up with an empty set // of buffers. This ensures we can correctly determine whether to run // whole-module heap simulation. - buffers_to_assign_sequentially->emplace(computation, - FlatSet()); + buffers_to_assign_sequentially->emplace( + computation, flat_hash_set()); } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers @@ -1048,12 +1032,12 @@ Status BufferAssigner::AssignBuffersForComputation( return Status::OK(); } -FlatMap, - LogicalBuffer::Color::Hasher> +flat_hash_map, + LogicalBuffer::Color::Hasher> BufferAssigner::SplitBuffersByColor( - const FlatSet& buffers) { - FlatMap, - LogicalBuffer::Color::Hasher> + const flat_hash_set& buffers) { + flat_hash_map, + LogicalBuffer::Color::Hasher> color_map; for (auto buffer : buffers) { color_map[buffer->color()].insert(buffer); @@ -1062,27 +1046,42 @@ BufferAssigner::SplitBuffersByColor( } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const FlatMap>& + const flat_hash_map>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + + // Returns a heap algorithm that chooses the best result from several + // algorithms. 
+ auto get_heap_algorithm = [&](int64 alignment) { + auto algorithms = + absl::make_unique>>(); + algorithms->push_back(absl::make_unique( + absl::make_unique(alignment))); + algorithms->push_back( + absl::make_unique(alignment)); + return absl::make_unique(std::move(algorithms)); + }; + if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; - SequentialHloOrdering::HloModuleSequence module_sequence; - FlatSet all_buffers_to_assign; + HloSchedule schedule(&assignment->module()); + flat_hash_set all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet& buffers_to_assign = pair.second; + const flat_hash_set& buffers_to_assign = + pair.second; const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - module_sequence[computation] = *instruction_sequence; + schedule.set_sequence(computation, *instruction_sequence); all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1098,9 +1097,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique( - absl::make_unique(alignment)), - assignment->module(), module_sequence, + HeapSimulator::Run(get_heap_algorithm(alignment), + assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1113,7 +1111,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(1) << "Running per-computation heap simulation"; for (const auto& 
pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet& buffers_to_assign = pair.second; + const flat_hash_set& buffers_to_assign = + pair.second; const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1128,12 +1127,10 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run( - absl::make_unique( - absl::make_unique(alignment)), - *computation, *instruction_sequence, - assignment->points_to_analysis(), assignment->buffer_size_, - options)); + HeapSimulator::Run(get_heap_algorithm(alignment), *computation, + HloInstructionSequence(*instruction_sequence), + assignment->points_to_analysis(), + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1150,9 +1147,8 @@ std::vector ComputePeakMemoryLogicalBuffers( const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical // buffers in this allocation. - tensorflow::gtl::FlatMap - id_to_buffer; - tensorflow::gtl::FlatMap buffer_sizes; + absl::flat_hash_map id_to_buffer; + absl::flat_hash_map buffer_sizes; for (const auto& pair : allocation.assigned_buffers()) { const LogicalBuffer* buffer = pair.first; const BufferAllocation::OffsetSize& offset_size = pair.second; @@ -1191,7 +1187,7 @@ std::vector ComputePeakMemoryLogicalBuffers( // Next gather the set of logical buffers live at the earliest point of // maximal live set size. 
- tensorflow::gtl::FlatSet live_buffers; + absl::flat_hash_set live_buffers; live_size = 0; for (const auto& event : heap_trace.events()) { const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); @@ -1423,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets( // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile, kCall, and -// kConditional). +// kConditional and input output aliasing). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); + + // Set up colocated buffer set for input and output. + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + std::vector colocated_set; + AddBufferToColocatedSet(module->entry_computation()->root_instruction(), + output_index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet( + module->entry_computation()->parameter_instruction(param_number), + param_index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { if (computation->IsFusionComputation()) { continue; @@ -1581,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - FlatSet* colocated_buffers, - FlatSet* colocated_allocations) { + flat_hash_set* colocated_buffers, + flat_hash_set* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; // Set 'entry_parameter_number' and 
'entry_parameter_shape_idx' if entry @@ -1655,8 +1666,8 @@ StatusOr> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - FlatSet colocated_buffers; - FlatSet colocated_allocations; + flat_hash_set colocated_buffers; + flat_hash_set colocated_allocations; std::vector colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); @@ -1674,7 +1685,7 @@ StatusOr> BufferAssigner::CreateAssignment( // First assign buffers for global computatations. Temporary buffers for // sequential computations are collected in 'buffers_to_assign_sequentially'. - FlatMap> + flat_hash_map> buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 94495290c131e22392079dc2d0237d990b646d3e..899cd36e1f98c9e7b8ba7e42c06ced5c3e8afcc8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -22,6 +22,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,15 +35,23 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations); + // This class abstracts an allocation of contiguous memory which can hold the // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range // of the allocation, represented by a Slice. A single BufferAllocation may hold @@ -137,7 +148,7 @@ class BufferAllocation { // Access to the logical buffers assigned to this allocation, and their // associated logical offsets and sizes. - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& assigned_buffers() const { return assigned_buffers_; } @@ -312,7 +323,7 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. 
- tensorflow::gtl::FlatMap assigned_buffers_; + absl::flat_hash_map assigned_buffers_; int64 fragmentation_bytes_ = 0; std::vector heap_traces_; @@ -489,7 +500,7 @@ class BufferAssignment { int64 temp_allocation_total_size_ = 0; // Maps Buffers to the index of the BufferAllocation which holds the buffer. - tensorflow::gtl::FlatMap + absl::flat_hash_map allocation_index_for_buffer_; const HloModule* module_; @@ -543,11 +554,10 @@ class BufferAssigner { // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet& colocated_buffers, - const tensorflow::gtl::FlatSet& - colocated_allocations, - tensorflow::gtl::FlatMap>* + const absl::flat_hash_set& colocated_buffers, + const absl::flat_hash_set& colocated_allocations, + absl::flat_hash_map>* buffers_to_assign_sequentially, BufferAssignment* assignment); @@ -557,9 +567,8 @@ class BufferAssigner { // 'run_whole_module_heap_simulation' is true, the heap simulation will be run // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( - const tensorflow::gtl::FlatMap< - const HloComputation*, - tensorflow::gtl::FlatSet>& + const absl::flat_hash_map>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment); @@ -579,7 +588,7 @@ class BufferAssigner { // alias. Explicitly handling these colocated buffers is necessary because // points-to analysis is computation level scope and does not recognize // aliasing across computations (b/32491382). 
- using ColocatedBufferSet = tensorflow::gtl::FlatSet; + using ColocatedBufferSet = absl::flat_hash_set; // Returns a vector of ColocatedBufferSet objects, where each // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' @@ -594,8 +603,8 @@ class BufferAssigner { void AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet* colocated_buffers, - tensorflow::gtl::FlatSet* colocated_allocations); + absl::flat_hash_set* colocated_buffers, + absl::flat_hash_set* colocated_allocations); // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining // the invariant that all sets in 'colocated_buffer_sets' are disjoint. @@ -613,11 +622,10 @@ class BufferAssigner { // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. - tensorflow::gtl::FlatMap, - LogicalBuffer::Color::Hasher> - SplitBuffersByColor( - const tensorflow::gtl::FlatSet& buffers); + absl::flat_hash_map, + LogicalBuffer::Color::Hasher> + SplitBuffersByColor(const absl::flat_hash_set& buffers); // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 52abda16c4ee8e494b596e0690a8067743380054..795beb9ff5ceb2998a85fbd03d8bb1d3b2febc12 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -30,16 +30,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -79,9 +81,8 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloTestBase { +class BufferAssignmentTest : public HloVerifiedTestBase { protected: - BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -119,16 +120,12 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentWithInstructionSequence( HloModule* module, - tensorflow::gtl::ArraySlice instruction_sequence, + absl::Span instruction_sequence, int64 alignment = 1) { - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[module->entry_computation()] = - std::vector(instruction_sequence.begin(), - instruction_sequence.end()); + HloSchedule schedule(module); + schedule.set_sequence(module->entry_computation(), instruction_sequence); return BufferAssigner::Run( - module, - absl::make_unique(module, - 
module_sequence), + module, absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -148,6 +145,17 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildReduceComputation(const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2)); + return builder.Build(); + } + // Builds a simple compare-to-limit (x < 4) computation for a While. // // condition: @@ -164,8 +172,8 @@ class BufferAssignmentTest : public HloTestBase { HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4)); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4)); return builder.Build(); } @@ -312,12 +320,12 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -336,13 +344,13 @@ TEST_F(BufferAssignmentTest, BufferForConst) { module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); 
EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -364,7 +372,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -387,7 +395,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // The copy node now has an output buffer. 
GetAssignedOutputAllocation(*buffers, copy); } @@ -401,12 +409,14 @@ TEST_F(BufferAssignmentTest, Basic) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -414,7 +424,7 @@ TEST_F(BufferAssignmentTest, Basic) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -448,12 +458,14 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -473,7 +485,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module.get(), colorer); + auto buffers = RunColoredBufferAssignment(module, colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -507,12 +519,14 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -540,7 +554,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module.get(), colorer); + auto buffers = RunColoredBufferAssignment(module, colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -577,12 +591,14 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( @@ -590,7 +606,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -641,7 +657,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -676,10 +692,10 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // output. (Reuse is not safe in the general case, as it reshapes and some // out-of-order reductions could overwrite an element before a use.) 
// - // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) + // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) auto module = CreateNewModule(); auto reduce_computation = - module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( @@ -700,7 +716,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -756,7 +772,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -821,7 +837,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -859,7 +875,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -888,7 +904,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto module = CreateNewModule(); 
module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -921,7 +937,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -958,7 +974,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -993,7 +1009,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1025,7 +1041,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // negate and broadcast should share a buffer. 
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1063,7 +1079,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1107,7 +1123,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Allocations for the map computation should be thread-local and not // live-out. @@ -1156,7 +1172,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1192,7 +1208,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Only some of the elements of the input param are liveout. EXPECT_FALSE( @@ -1229,13 +1245,14 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. 
auto builder = HloComputation::Builder(TestName()); + Literal elements[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1249,7 +1266,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { /*operands=*/{}, /*custom_call_target=*/"foo_function")); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1280,7 +1297,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1342,7 +1359,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Buffers for call are colocated with the sub-computations. 
EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1378,7 +1395,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Bitcast should get the same allocation as the param. EXPECT_EQ(1, assignment->Allocations().size()); @@ -1405,7 +1422,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. @@ -1443,7 +1460,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. 
@@ -1472,17 +1489,20 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_ab = builder.AddInstruction( - HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); - auto dot_bc = builder.AddInstruction( - HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( + shape_2x4, param_a, param_b, dot_dnums, precision_config)); + auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot( + shape_3x4, param_b, param_c, dot_dnums, precision_config)); builder.AddInstruction( - HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); + HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); + auto assignment = RunBufferAssignment(module, /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1501,7 +1521,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. 
- assignment = RunBufferAssignment(module.get(), /*alignment=*/64); + assignment = RunBufferAssignment(module, /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1532,12 +1552,14 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -1545,16 +1567,13 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); - // Trivially, the set of peak memory logical buffer(s) of an allocation with a - // single logical buffer should be exactly the logical buffer in that - // allocation. 
const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = mul_buffer.PeakMemoryLogicalBuffers(); ASSERT_EQ(peak_buffers.size(), 1); - EXPECT_EQ(peak_buffers[0]->instruction(), mul); + EXPECT_EQ(peak_buffers[0]->instruction(), broadcast); } TEST_F(BufferAssignmentTest, PeakBuffers) { @@ -1590,7 +1609,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module.get(), {param, log, rev, neg, concat, root}); + module, {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1646,7 +1665,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1696,15 +1715,13 @@ ENTRY main { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); - + ParseAndVerifyModule(hlo_text); HloInstruction* constant_1 = - module->entry_computation()->GetInstructionWithName("constant.1.1"); + module().entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module->entry_computation()->GetInstructionWithName("constant.1.2"); + module().entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(&module()); { const BufferAllocation& allocation_for_const_1 = @@ -1733,7 +1750,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloTestBase { +class WhileBufferAssignmentTest : public HloVerifiedTestBase { protected: 
std::unique_ptr BuildWhileConditionComputation( const string& name) { @@ -1767,11 +1784,10 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, - absl::make_unique(module, sequence), + module, absl::make_unique(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1807,9 +1823,9 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1833,8 +1849,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // Verify 'input0' and read-only use while0{0} alias. 
EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1890,20 +1906,20 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); + ParseAndVerifyModule(module_str); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module->instruction_count(); + int64 instruction_count = module().instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); - ASSERT_EQ(instruction_count, module->instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(&module()).status()); + ASSERT_EQ(instruction_count, module().instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* bcast = + module().entry_computation()->root_instruction(); const HloInstruction* param = - module->entry_computation()->parameter_instruction(0); + module().entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1911,7 +1927,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. 
- auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(&module()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -1958,20 +1974,20 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); + ParseAndVerifyModule(module_str); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module->instruction_count(); + int64 instruction_count = module().instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); - ASSERT_EQ(instruction_count, module->instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(&module()).status()); + ASSERT_EQ(instruction_count, module().instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* bcast = + module().entry_computation()->root_instruction(); const HloInstruction* constant = - module->entry_computation()->GetInstructionWithName("constant.42"); + module().entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1979,7 +1995,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(&module()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2072,24 +2088,31 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. 
int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_IS_OK(copy_insertion.Run(module).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = { - token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + schedule.set_sequence( + module->entry_computation(), + {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); + TF_ASSERT_OK(schedule.Verify()); + TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run( - module.get(), - absl::make_unique(module.get(), sequence), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. 
TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2122,7 +2145,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2143,8 +2166,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // while0 and while1 buffers should be completely aligned. EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2186,13 +2209,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); } - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2216,15 +2239,14 @@ ENTRY Main { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString( - hlo_text, legacy_flags::GetDebugOptionsFromFlags())); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + ParseAndVerifyModule(hlo_text, config); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = 
RunBufferAssignment(&module()); - HloComputation* main = module->entry_computation(); - HloComputation* callee = module->GetComputationWithName("Callee"); + HloComputation* main = module().entry_computation(); + HloComputation* callee = module().GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2247,29 +2269,6 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } -static bool IsPostOrderTraversal( - const std::vector& sequence) { - tensorflow::gtl::FlatSet seen_so_far; - auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { - return seen_so_far.count(instruction) == 0; - }; - - for (auto instruction : sequence) { - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), has_not_been_seen_yet) || - std::any_of(instruction->control_predecessors().begin(), - instruction->control_predecessors().end(), - has_not_been_seen_yet)) { - return false; // Not a post order. - } - if (!seen_so_far.insert(instruction).second) { - return false; // Not a "traversal". 
- } - } - - return true; -} - TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2284,14 +2283,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto weights0 = builder.AddInstruction( HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto input1 = builder.AddInstruction( HloInstruction::CreateParameter(2, data_shape_, "input1")); auto weights1 = builder.AddInstruction( HloInstruction::CreateParameter(3, data_shape_, "weights1")); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, one, {1})); + HloInstruction::CreateBroadcast(data_shape_, one, {})); auto cond = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2311,41 +2310,40 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); auto gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); - auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, gte0, gte1)); + auto root_add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1)); module->AddEntryComputation(builder.Build()); { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); EXPECT_TRUE(result); } - RunCopyInsertion(module.get()); + RunCopyInsertion(module); - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo sequence 
for the + // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + schedule.set_sequence(module->entry_computation(), + {input1, weights1, one, output1, while1->operand(0), + while1, input0, weights0, zero, output0, + while0->operand(0), while0, gte0, gte1, root_add}); - // If this ASSERT_TRUE fails, we constructed a bogus sequence above - // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + // If this ASSERT fails, we constructed a bogus sequence above and this test + // itself is buggy. + TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run( - module.get(), - absl::make_unique(module.get(), sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run(module, + absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); @@ -2363,9 +2361,9 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ 
-2396,8 +2394,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // Get BufferAllocation for root instruction. auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 8d0ac3b84a90dccef4732cc2e63e3a24741f4932..9b2783a214a686f3148723d19bbc94421fc8b4e4 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -75,19 +75,17 @@ Status BufferLiveness::Analyze() { string BufferLiveness::ToString() const { std::vector pieces; - pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):", - module_->name().c_str())); + pieces.push_back( + absl::StrFormat("BufferLiveness(module=%s):", module_->name())); pieces.push_back("HloOrdering:"); pieces.push_back(hlo_ordering_->ToString()); - pieces.push_back(tensorflow::strings::Printf("Aliased buffers:")); + pieces.push_back("Aliased buffers:"); for (const LogicalBuffer* buffer : aliased_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - pieces.push_back(tensorflow::strings::Printf("Live out buffers:")); + pieces.push_back("Live out buffers:"); for (const LogicalBuffer* buffer : maybe_live_out_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } return absl::StrJoin(pieces, "\n"); } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index cdd3cf4032ef6916086e1c2d148b575192503000..f939a426ead7c34092fc5234ef779ee857347a26 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -27,8 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -102,7 +101,7 @@ class BufferLiveness { // Set of LogicalBuffers which are aliased in the output of other // instructions. For example, a LogicalBuffer which is inserted into a tuple // is considered to be aliased and will be in this set. - tensorflow::gtl::FlatSet aliased_buffers_; + absl::flat_hash_set aliased_buffers_; // LogicalBuffers that may be live out of the entry computation. PointsToSet::BufferSet maybe_live_out_buffers_; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 26e26e316d6281a97f8317f8ed1d7a6f21b0d374..17e50905059ad2c92784d14132c1cb1f46f35ade 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, negate, exp, add}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, negate, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - 
SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, add, recv, - recv_done, send, send_done}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {param, add, token, recv, recv_done, send, send_done}); + TF_ASSERT_OK(schedule.Verify()); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. @@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + Literal elements0[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; + auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]}); + Literal element1 = LiteralUtil::CreateR0(3); + auto inner_tuple1 = LiteralUtil::MakeTuple({&element1}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inner_tuple0->shape(), tuple_constant, 0)); + inner_tuple0.shape(), tuple_constant, 0)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index 
f4be16e0843f64f41ef27539bf263ae98ce0ebf9..11d8abc5badf7b1a05239ed74a05be0c899e37a1 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -141,6 +141,9 @@ class BufferValue { // operator< is required for std::set. bool operator<(const BufferValue& other) const { return id_ < other.id_; } + bool operator==(const BufferValue& other) const { return id_ == other.id_; } + bool operator!=(const BufferValue& other) const { return id_ != other.id_; } + virtual string ToString() const = 0; // TODO(lauj) rename LogicalBufferProto to BufferValueProto. diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h index 305914fca828f110bf54239bddb1590172562b16..cc46af5eeec623e19637cd6245915b3a3124a2cd 100644 --- a/tensorflow/compiler/xla/service/buffer_value_containers.h +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet( return output; } -using BufferValueFlatSet = tensorflow::gtl::FlatSet; +using BufferValueFlatSet = absl::flat_hash_set; template BufferValueFlatSet ToBufferValueFlatSet( const LogicalBufferContainerT& logical_buffer_container) { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 37523a73ff403cc079038abe0975045ba6bf7361..bdd5069632e84fe6c67ca129f726432479ac1b35 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,21 +17,22 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrAppendFormat; using absl::StrCat; -using ::tensorflow::strings::Appendf; string CallContextToString(CallContext context) { switch (context) { @@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { bool CallGraph::DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet* visited) const { + absl::flat_hash_set* visited) const { if (a == b || ContainsKey(*visited, b)) { // The call graph is guaranteed to be acyclic so any previously visited node // we encounter was already determined to be dominated. @@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper( bool CallGraph::Dominates(const HloComputation* a, const HloComputation* b) const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; return DominatesHelper(a, b, &visited); } @@ -277,7 +278,7 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { Status CallGraph::VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet* visited) const { + absl::flat_hash_set* visited) const { auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. 
@@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal( Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes) const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; if (visit_unreachable_nodes) { // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { @@ -356,20 +357,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, string CallGraph::ToString() const { string out; - Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); + StrAppendFormat(&out, "Call graph for module %s:\n", module_->name()); for (const CallGraphNode& node : nodes()) { - Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); - Appendf(&out, " calls:\n"); + StrAppendFormat(&out, "Computation %s:\n", node.computation()->name()); + StrAppendFormat(&out, " calls:\n"); for (const HloComputation* callee : node.callees()) { - Appendf(&out, " %s\n", callee->name().c_str()); + StrAppendFormat(&out, " %s\n", callee->name()); } - Appendf(&out, " called by:\n"); + StrAppendFormat(&out, " called by:\n"); for (const HloComputation* caller : node.callers()) { - Appendf(&out, " %s\n", caller->name().c_str()); + StrAppendFormat(&out, " %s\n", caller->name()); } - Appendf(&out, " callsites:\n"); + StrAppendFormat(&out, " callsites:\n"); for (const CallSite& callsite : node.callsites()) { - Appendf(&out, " %s\n", callsite.ToString().c_str()); + StrAppendFormat(&out, " %s\n", callsite.ToString()); } } return out; diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 3af2ab5edfd9faf4ac5193df4b823c21b55b2f7f..cb56f4789d06ac33acdaadc8b619b9e37f683d58 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -20,11 +20,11 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -145,19 +145,19 @@ class CallGraphNode { // The computations called by this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector callees_; - tensorflow::gtl::FlatSet callee_set_; + absl::flat_hash_set callee_set_; // The computations which call this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector callers_; - tensorflow::gtl::FlatSet caller_set_; + absl::flat_hash_set caller_set_; // The call sites in this computation std::vector callsites_; // The map from instruction to index in callsites_ for looking up the callsite // (if any) associated with a particular instruction in this computation. - tensorflow::gtl::FlatMap callsite_instructions_; + absl::flat_hash_map callsite_instructions_; // The call sites in other computations which call this computation. std::vector caller_callsites_; @@ -250,14 +250,14 @@ class CallGraph { // 'visited'. Status VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet* visited) const; + absl::flat_hash_set* visited) const; // Recursive helper for computing whether 'a' dominates 'b' in the call // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), // and 'visited' is the set of computations which have been visited. bool DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet* visited) const; + absl::flat_hash_set* visited) const; // The HLO module represented by this call graph. 
const HloModule* module_ = nullptr; @@ -267,7 +267,7 @@ class CallGraph { // Map from HLO computation to the index of the corresponding call graph node // in nodes_. - tensorflow::gtl::FlatMap node_indices_; + absl::flat_hash_map node_indices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index cc80b7484313329104eec1ce71a150b47d8330c9..34f3f914d593bc603c4964663f9cafb70a136fd3 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloTestBase { +class CallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. 
std::unique_ptr MakeScalarComputation( @@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); 
EXPECT_FALSE(call_graph->IsFlattened()); @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(3, call_graph->nodes().size()); @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // Test visitation of only reachable nodes. { @@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. 
auto module = CreateNewModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 256d05a73e0bf61d959d21795c106286b52d0b19..1d4214044409ae06239506e610000c839450a030 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( "Could not find mapping from subcomputation HLO %s to a cloned HLO.", - subcomputation_hlo->ToString().c_str()); + subcomputation_hlo->ToString()); } return it->second; } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index c5cd88b9ea2a9c308786d4d7476316b1e592d40a..08c4aff4f7fc7fc332fc7f34ece019eb57d71f3a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -25,7 +25,7 @@ namespace xla { // For every kCall operation in the main computation, we inline the body of the // called function, and proceed recursively. -class CallInliner : public HloPassInterface { +class CallInliner : public HloModulePass { public: using InlinedInstructionMap = std::unordered_map; diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 5d85a3f173d50a964420e720f5c9b416731d948c..e6b566543594a86eb5369ee9b7440f62618f6c5a 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloTestBase; +using CallInlinerTest = HloVerifiedTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -91,6 +91,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { module->AddEmbeddedComputation(just_false.Build()); HloComputation::Builder call_false_builder(TestName() + ".call_false"); + call_false_builder.AddInstruction( + HloInstruction::CreateParameter(0, pred, "param")); call_false_builder.AddInstruction( HloInstruction::CreateCall(pred, {}, false_computation)); HloComputation* call_false = @@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); 
ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 601a3e9a01b83fffe09354c37cc3565ad6abdc72..3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::HOST_TO_DEVICE) { return FailedPrecondition( "host-to-device channels cannot be used with a Send operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } if (channel.has_sender) { return FailedPrecondition( "when registering send, passed a channel handle that is already used " - "by a sender: %lld", + "by a sender: %d", handle.handle()); } channel.has_sender = true; @@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = 
opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::DEVICE_TO_HOST) { return FailedPrecondition( "device-to-host channels cannot be used with a Recv operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } @@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (channel.receiver_count >= 1) { return FailedPrecondition( "when registering recv, passed a channel handle that is already used " - "by a receiver: %lld", + "by a receiver: %d", handle.handle()); } channel.receiver_count += 1; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index d773558c284a7d645f2766bb88c50f7da3777e5d..52037bf9b52556c6aa2e66dd3209e25cf085cfe3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 3079695e9674f4000fdf4c54ac1e78c98968aa27..6d67f970020d278cc7bf61b56350200d3e5cb926 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -62,12 +62,12 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, 
const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { - TF_RET_CHECK(instance.computation.has_program_shape()); + TF_RET_CHECK(instance.computation.has_host_program_shape()); const DebugOptions& debug_options = options.debug_options(); @@ -86,9 +86,11 @@ CompileOnlyService::CompileAheadOfTime( Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - const auto& program_shape = instance.computation.program_shape(); + const auto& program_shape = instance.computation.host_program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; + *execution_options.mutable_shape_with_output_layout() = + *instance.result_layout; TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(program_shape, instance.argument_layouts, @@ -97,12 +99,14 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, - metadata); + return compiler_->CompileAheadOfTime( + absl::make_unique(hlo_modules[0]->name(), + absl::MakeSpan(hlo_modules)), + options, metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 1ac950bdd66bd034dfdafa8598ec506221e99c2f..61136a3e11fe15fb74eac257f46292c6cd24ce7d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -50,12 +50,12 @@ class CompileOnlyService : public Service { // |CompileOnlyClient::CompileAheadOfTime| for additional details. 
StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options); StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6b3b9820f09803c8a04504e6c35c22de51abf04b..80c630c6201503d88a690f04a88f6fca6f3a438a 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -45,7 +45,7 @@ Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, // Define a default version where metadata is not used. StatusOr>> Compiler::CompileAheadOfTime( - std::vector> modules, + std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata) { if (metadata != nullptr) { @@ -53,7 +53,7 @@ Compiler::CompileAheadOfTime( "Populating AotCompilationMetadata is not implemented on this " "compiler."); } - return CompileAheadOfTime(std::move(modules), options); + return CompileAheadOfTime(std::move(module_group), options); } /* static */ std::map* @@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() { return NotFound( "could not find registered compiler for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 34f7fe12cac5a4dcd3822865bee903d6eabc25c0..9ab179303b3e792c1f94c08626d7bc1afd2099f8 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -26,15 +26,16 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -135,6 +136,12 @@ class Compiler { std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; + // Optimizes a HLO module group, a set of module which runs concurrently on + // multiple devices potentially communicating data between the modules. + virtual Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) = 0; + // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses @@ -145,12 +152,18 @@ class Compiler { // (not just type of device) indicated by the executor. // // device_allocator is optional; see RunHloPasses. - // - // Use the overload below to compile computations that run in parallel. virtual StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; + // Compiles a set of HLO modules that can run in parallel, potentially + // communicating data between the modules. 
+ virtual StatusOr>> + RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) = 0; + // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. @@ -160,7 +173,7 @@ class Compiler { // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; @@ -184,16 +197,16 @@ class Compiler { ComputeDefaultBackendConfig(const HloInstruction& hlo, se::StreamExecutor* executor) const; - // Compiles the HLO module for ahead-of-time execution. This is intended for - // use in static compilation. + // Compiles the HLO module group for ahead-of-time execution. This is + // intended for use in static compilation. virtual StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) = 0; // Similar to CompileAheadOfTime above but AotCompilationMetadata // has an argument that can be populated during compilation. 
virtual StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index af8f7f1027a40703137d6880a9865449c560a47b..efc893818d03a20d6bd65b7dc1da72ea5da5ceb0 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -56,4 +56,14 @@ string ComputationLayout::ToString() const { result_layout_.ToString()); } +ProgramShape ComputationLayout::ComputeProgramShape() const { + ProgramShape program_shape; + for (int64 i = 0; i < parameter_layouts_.size(); ++i) { + *program_shape.add_parameters() = parameter_layouts_[i].shape(); + *program_shape.add_parameter_names() = absl::StrCat("p", i); + } + *program_shape.mutable_result() = result_layout_.shape(); + return program_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 6975f387b4864bf28ea0ad23d7d4602b5b346e08..a2fb656677f354fbf85ff613d826cd6be86ba3bf 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -83,6 +83,10 @@ class ComputationLayout { // Returns a string representation of this object. string ToString() const; + // Create a ProgramShape proto based on the parameter and result shapes held + // within this object. 
+ ProgramShape ComputeProgramShape() const; + private: std::vector parameter_layouts_; ShapeLayout result_layout_; diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 61b1dba6c9222dc487003eb08189ee71eaafedd2..2210a8578ad73efb27dc9c230b142c55228d2af5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -132,7 +132,7 @@ StatusOr ComputationPlacer::AssignDevices( return NotFound( "could not find registered computation placer for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.placer == nullptr) { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 3de50cbd7ff752e8722a103b68f75144c6c889cd..2223ad67534dc31fc2c56ce68bdc87e881f20f32 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that removes kConditional with a constant predicate, replacing them // with their true or false computation as appropriate. 
-class ConditionalSimplifier : public HloPassInterface { +class ConditionalSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-conditional"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 6c477da03820681e381dd64978d30edf27e2c422..c43a31b167d47af3c92ed35fa52594fa5da1e4af 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -39,10 +39,6 @@ namespace op = xla::testing::opcode_matchers; class ConditionalSimplifierTest : public HloVerifiedTestBase { public: - ConditionalSimplifierTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - // Makes a computation that contains a conditional with constant predicate. HloComputation* MakeConditional(HloModule* module); }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 9c81a86bbb9dc7078237fe200f510a4905cb4d8d..0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); auto new_filter = add( @@ -223,8 +223,8 @@ 
Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { filter_mask, expanded_filter, zero_filter)); auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - convolution->window(), dim_numbers, /*feature_group_count=*/1); - new_convolution->set_precision_config(convolution->precision_config()); + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index 498894737fa37a6d8cca6ead2a86c72eb84ababd..ce0138e56fbd51daaf5d3ac329ccbe31a9fdbde7 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -25,7 +25,7 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloPassInterface { +class ConvolutionFeatureGroupConverter : public HloModulePass { public: ConvolutionFeatureGroupConverter() {} diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 1b7a7b36eac31f972e1166e17859cc0c64265538..245db6be2a400a7447f1e87317018cbb1572c405 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -40,10 +40,12 @@ namespace { using absl::StrAppend; -bool IsEntryParameterValue(const HloValue& value) { +bool IsReadonlyEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); + computation == computation->parent()->entry_computation() && + !computation->parent()->input_output_alias_config().ParameterHasAlias( + value.defining_instruction()->parameter_number(), value.index()); } bool IsConstantValue(const HloValue& value) { @@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) { } bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsEntryParameterValue(value); + return IsConstantValue(value) || IsReadonlyEntryParameterValue(value); } // Data structure describing the action which should be taken on parts of a @@ -79,8 +81,7 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, bool ShouldCopyRootValue(const HloValue& value, const SpecialCaseCopyPolicy& policy) { if (policy.copy_parameters_and_constants) { - return IsConstantValue(value) || - value.defining_instruction()->opcode() == HloOpcode::kParameter; + return ValueIsReadOnly(value); } return false; } @@ -332,6 +333,81 @@ Status 
AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Conservatively adds copies before root instruction of entry computation and +// each aliased parameter to resolve interference of aliased input and output +// buffer. We later rely on the CopyRemover to drop the unnecessary ones. +Status AddCopiesForAliasedInputOutputs(HloModule* module) { + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + + ShapeTree output_indices_to_copy(root->shape()); + std::vector> copied_parameters; + bool has_alias = false; + for (auto* param : entry->parameter_instructions()) { + bool param_has_alias = false; + ShapeTree param_indices_to_copy(param->shape()); + + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + if (param_number == param->parameter_number()) { + param_has_alias = true; + *(param_indices_to_copy.mutable_element(param_index)) = true; + *(output_indices_to_copy.mutable_element(output_index)) = true; + } + }); + + if (!param_has_alias) { + continue; + } + + has_alias = true; + // Store a snapshot of users before DeepCopyInstruction, as + // DeepCopyInstruction introduces new users of the instruction. + std::vector users = param->users(); + ShapeTree param_copy_tree(param->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * copied, + entry->DeepCopyInstruction( + param, ¶m_indices_to_copy, ¶m_copy_tree)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); + } + + copied_parameters.push_back(param_copy_tree); + } + + if (!has_alias) { + return Status::OK(); + } + + // Add copies before root instruction. 
+ ShapeTree output_copy_tree(root->shape(), + /*init_value=*/nullptr); + + TF_ASSIGN_OR_RETURN(HloInstruction * root_copied, + root->parent()->DeepCopyInstruction( + root, &output_indices_to_copy, &output_copy_tree)); + + // Add control dependencies between the input/output copies. + TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& input_index) -> Status { + HloInstruction* from = + copied_parameters[param_number].element(input_index); + HloInstruction* to = output_copy_tree.element(output_index); + + TF_RET_CHECK(from != nullptr); + TF_RET_CHECK(to != nullptr); + TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to)); + return Status::OK(); + })); + + entry->set_root_instruction(root_copied); + + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -432,7 +508,7 @@ class CopyRemover { // Construct a list for each HLO buffer in the alias analysis. Maintain a // map from HloValue to the respective list element representing that // value. The map is used to construct the copy info map below. - tensorflow::gtl::FlatMap value_to_node; + absl::flat_hash_map value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { // Verify values contained in the buffer are strictly ordered. This // should always be the case after adding copies to eliminate @@ -479,8 +555,8 @@ class CopyRemover { // 'values' an entry is created in value_to_node which indicates the // respective ValueNode representing that value. 
void AddValueList( - tensorflow::gtl::ArraySlice values, - tensorflow::gtl::FlatMap* value_to_node) { + absl::Span values, + absl::flat_hash_map* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; for (const HloValue* value : values) { @@ -516,8 +592,7 @@ class CopyRemover { // respective ValueNode. void CreateCopyMap( const HloModule& module, - const tensorflow::gtl::FlatMap& - value_to_node) { + const absl::flat_hash_map& value_to_node) { for (HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with @@ -905,7 +980,7 @@ class CopyRemover { // The heads of all the value lists. Each value list represents the HLO // values contained in a particular HLO buffer. The values in the list are // in dependency order. - tensorflow::gtl::FlatSet value_lists_; + absl::flat_hash_set value_lists_; // Copy removal requires fast access to the value list elements // corresponding to the source and destination values of the kCopy @@ -916,7 +991,7 @@ class CopyRemover { ValueNode* src = nullptr; ValueNode* dest = nullptr; }; - tensorflow::gtl::FlatMap copy_map_; + absl::flat_hash_map copy_map_; }; HloModule* module_; @@ -954,6 +1029,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } } } + + TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module)); return Status::OK(); } @@ -1010,7 +1087,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. 
- tensorflow::gtl::FlatSet seen; + absl::flat_hash_set seen; ShapeUtil::ForEachSubshape( root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { std::vector buffers_at_index = diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index d308f6bc84670b78b9cab476f2893bce267df2cf..c097089e30d59936a32f69c49123c398f1611ea3 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -43,7 +43,7 @@ namespace xla { // (3) The buffer set of the root instruction of the entry computation must be // unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and // InstructionAliasSet::IsDistinct return true. -class CopyInsertion : public HloPassInterface { +class CopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 892d0d7b547aaf1e7f1c55e4163d1e1fd9518def..4533ebb99bbba854a029fb8a9a1e31b023be720d 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1351,6 +1351,218 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); } +TEST_F(CopyInsertionTest, CrossingParameters) { + // Test a case where two parameters' dataflow cross with each other while + // input and output are aliased with same index: + // + // (p0 , p1) + // | \ /| + // | \ / | + // alias X alias + // | / \ | + // | / \| + // (p1 , p0) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "0")); + auto gte0 = builder.AddInstruction( + 
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 4); +} + +TEST_F(CopyInsertionTest, ParametersAliasing) { + // Test a case where two parameters' dataflow don't interfere with each other + // while aliased. + // + // (p0 , p1) + // | | + // | | + // alias alias + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { + // Test a case where no parameter is aliased with result. 
In this case, copy + // should be added + // + // (p0 , p1) + // | | + // | | + // | | + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(param, 0)), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // (p0 , p1) + // | | + // | | + // alias | + // | | + // | | + // (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::GetTupleElement(param, 0), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | | | + // +-- (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | Add----+ + // | | | + // +-- (p0 , p1) + auto module = CreateNewModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, negate0, negate1)); + builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // Test a while instruction with a body which permutes its tuple parameter // elements and applies one operation to one of the elements. 
The addition of diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e01fecffd00e50cb06f9f19eb44de9d329547298..58abb330a6e31e9b7a8081cd7964cf89a5b64a09 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -50,7 +50,9 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -63,6 +65,7 @@ cc_library( "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -89,7 +92,9 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ":target_machine_features", + "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -119,11 +124,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", - "//tensorflow/compiler/xla/service:hlo_scheduling", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", @@ -177,6 +181,7 @@ cc_library( 
":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", + ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", @@ -235,6 +240,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@llvm//:orc_jit", ], ) @@ -277,12 +284,17 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", @@ -299,6 +311,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@llvm//:analysis", "@llvm//:target", ], @@ -328,6 +341,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -338,12 +352,12 @@ cc_library( hdrs = ["parallel_loop_emitter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", 
"//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -391,6 +405,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -404,6 +419,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", @@ -450,12 +466,17 @@ cc_library( ], copts = runtime_copts(), deps = [ + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", ], ) @@ -611,6 +632,18 @@ cc_library( ], ) +cc_library( + name = "runtime_key_value_sort", + srcs = ["runtime_key_value_sort.cc"], + hdrs = ["runtime_key_value_sort.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], @@ -645,6 +678,7 @@ tf_cc_test( "//tensorflow/core:test", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -657,9 +691,11 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/types:span", ], ) @@ -730,6 +766,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -754,6 +791,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -785,6 +823,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -806,6 +845,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -909,6 +949,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -929,6 +970,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -954,6 +996,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc 
b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc index 408fe0f5bf5d729165eadd532d4740211620645d..1942ea1a2af8a349de53bafe80977436f9740fc4 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -40,7 +40,7 @@ std::vector CreateBufferInfosFromBufferAssignment( } std::vector CreateArgIndexTableFromBufferInfos( - tensorflow::gtl::ArraySlice buffer_infos) { + absl::Span buffer_infos) { std::vector result; for (int64 i = 0; i < buffer_infos.size(); i++) { if (buffer_infos[i].is_entry_parameter()) { diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h index 05de70c72686dcbdaf0b47c46cde23ed45abdb42..e9ee928ab290f2f5338bd7b3804dc43033e2042f 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment( // If this function returns V then entry parameter i has buffer allocation index // V[i]. 
std::vector CreateArgIndexTableFromBufferInfos( - tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo> + absl::Span buffer_infos); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 098ce17a568fd3fb531020e7731100fabda43721..2d9978404cc9ec1e40fc61aaf794a8f1f06050bb 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -130,9 +130,9 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { // change the dimension mapping but not the dimension sizes. For // example, input height and width are the same as before the reshapes. HloInstruction* new_conv = module->entry_computation()->AddInstruction( - HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, - hlo->window(), new_dnums)); - new_conv->set_precision_config(hlo->precision_config()); + HloInstruction::CreateConvolve( + new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), + hlo->window(), new_dnums, hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 59437e88af27528654a0af86baf69ec7a1e91d60..becee3f81fc34c73040d53e4a261bc3a656cd78c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -31,7 +31,7 @@ namespace cpu { // called canonical convolutions). This pass expands non-canonical convolutions // into reshapes and canonical convolutions, so that these non-canonical // convolutions can run faster. 
-class ConvCanonicalization : public HloPassInterface { +class ConvCanonicalization : public HloModulePass { public: explicit ConvCanonicalization( const TargetMachineFeatures* target_machine_features) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 547d4c696da5cfdde3dece03250ae5fa51c92f25..2083f440fdd971db1b675d005664d25e6de53dbe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloTestBase { +class ConvCanonicalizationTest : public HloVerifiedTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -84,7 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -95,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + 
EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -146,7 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -156,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 279aa42fe23e5f1f1eeaf9f6303097a6e1a8f8a1..da01c0caf2a6665f71cc087270b21fffdd6caa0d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -77,17 +77,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" @@ -249,9 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding - // where we will take this pass in future. - // pipeline.AddPass(); + pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. 
@@ -308,7 +306,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout(), target_machine_features); + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, target_machine_features); return pipeline.Run(module).status(); } @@ -328,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass>( "simplification after layout assignement"); - pass.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + // TODO(b/117156505): When the bug is fixed, the CPU backend should not + // produce layout changing elementwise operations. We will then pass + // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to + // enable stricter verification. + pass.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -584,17 +588,14 @@ StatusOr> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. + // Run buffer allocation on the HLO graph. 
TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run(module.get(), - absl::make_unique( - module.get(), module_sequence), + absl::make_unique(schedule), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -628,9 +629,10 @@ StatusOr> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -638,9 +640,10 @@ StatusOr> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -673,9 +676,12 @@ StatusOr> CpuCompiler::RunBackend( } StatusOr>> -CpuCompiler::CompileAheadOfTime(std::vector> modules, +CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) { - TF_RET_CHECK(!modules.empty()); + TF_RET_CHECK(!module_group->empty()); + std::vector> modules = + module_group->ConsumeModules(); + std::call_once(llvm_command_line_options_initialized, &llvm_ir::InitializeLLVMCommandLineOptions, modules[0]->config()); @@ -705,8 +711,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); if (target == nullptr) { - 
return InternalError("TargetRegistry::lookupTarget failed: %s", - error.c_str()); + return InternalError("TargetRegistry::lookupTarget failed: %s", error); } llvm::Reloc::Model reloc_model = llvm::Reloc::Static; @@ -773,20 +778,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module, - absl::make_unique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, assignment->ToString()); @@ -826,18 +829,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_top_level_computation=*/true, - &module_sequence.at(computation))); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + &schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 47b5edabff79d1df23cbeae0823536bbdcd07aaa..c67307548dda731f8fa56b8e6790e7e83f587113 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -142,7 +142,7 @@ class CpuCompiler : public LLVMCompiler { DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) override; se::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index d49f7d7cc2d9b1d00847feda62fa62dd740820d8..076235f8874b5de57075fb690dd1b9111b6838a6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -30,7 +30,7 @@ namespace xla { // // TODO(b/62548313): Remove this when buffer assignment is smarter // (module-scoped). -class CpuCopyInsertion : public HloPassInterface { +class CpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 4db7fa446ea9188940f930bcadf753bd3e6b79e3..c9fb34be1cd582c71618c770c892058c233c571a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloTestBase { +class CpuCopyInsertionTest : public HloVerifiedTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*module), 3); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index fbcbbbd200d80fc18272ac628f230fcf13332aed..29abf38e439d919ff93629ed992cb3ff93a929bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -37,7 +38,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -75,9 +75,9 @@ CpuExecutable::CpuExecutable( StatusOr, std::vector>> -CpuExecutable::CreateTempArray( +CpuExecutable::CreateBufferTable( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::vector unowning_buffers( assignment_->Allocations().size()); std::vector owning_buffers( @@ -136,19 +136,19 @@ CpuExecutable::CreateTempArray( Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, + absl::Span buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, - // void** temps_array) + // void** buffer_table) // // result: Points at the result. // run_options: the ExecutableRunOptions object. // args_array: null - // temps_array: An array of pointers, containing pointers to temporary buffers - // required by the executable adn pointers to entry computation - // parameters. + // buffer_table: An array of pointers, containing pointers to temporary + // buffers required by the executable adn pointers to entry computation + // parameters. 
// uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -171,20 +171,19 @@ Status CpuExecutable::ExecuteComputeFunction( void* result_buffer = buffer_pointers[result_slice.index()]; if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; - VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[null], void* temps[%zu], " - "uint64 profile_counters[%zu])", + VLOG(3) << absl::StrFormat( + " func(void* result, void* params[null], void* buffer_table[%u], " + "uint64 profile_counters[%u])", buffer_pointers.size(), profile_counters_size); - VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); + VLOG(3) << absl::StrFormat(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { - absl::StrAppend(out, tensorflow::strings::Printf("%p", p)); + absl::StrAppend(out, absl::StrFormat("%p", p)); }; VLOG(3) << " params = nullptr"; - VLOG(3) << tensorflow::strings::Printf( - " temps = [%s]", - absl::StrJoin(buffer_pointers, ", ", ptr_printer).c_str()); - VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters); + VLOG(3) << absl::StrFormat( + " buffer_table = [%s]", + absl::StrJoin(buffer_pointers, ", ", ptr_printer)); + VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters); } compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), @@ -209,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction( StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers) { + absl::Span buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/result_shape(), @@ -247,7 +246,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* 
hlo_execution_profile) { TF_ASSIGN_OR_RETURN( auto result, @@ -258,7 +257,7 @@ StatusOr CpuExecutable::ExecuteOnStream( StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " @@ -269,7 +268,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { if (GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -283,11 +282,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( std::vector unowning_buffers; TF_ASSIGN_OR_RETURN( std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), + arguments)); - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &owning_buffers)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers))); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -300,7 +300,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( // // We also need to change the types of some of the variables we capture: // run_options needs to change from a pointer to a value type, and arguments - // needs to change from an ArraySlice into a vector. We use a struct instead + // needs to change from a Span into a vector. We use a struct instead // of a lambda to make this explicit. 
struct AsyncRunTask { CpuExecutable* executable; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 96e53de57eee013fe6f847c10e23a38f5beb9adc..3c3c047bfe8ee0d1ad90ede2432a86264f47870b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -57,12 +57,12 @@ class CpuExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -74,9 +74,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. 
- using ComputeFunctionType = void (*)( - void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); + using ComputeFunctionType = + void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*buffer_table*/, + int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -92,18 +93,18 @@ class CpuExecutable : public Executable { // exists) must out-live the task. StatusOr ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile); - // Creates an array suitable for passing as the "temps" argument to the JIT - // compiled function pointer. + // Creates an array suitable for passing as the "buffer_table" argument to the + // JIT compiled function pointer. // // Returns (unowning_buffers, owning_buffers) where: // - // - unowning_buffers.data() can be passed as the temps argument as-is and - // includes pointers to the scratch storage required by the computation, - // the live-out buffer into which the result will be written and entry - // computation parameters. + // - unowning_buffers.data() can be passed as the buffer_table argument as-is + // and includes pointers to the scratch storage required by the + // computation, the live-out buffer into which the result will be written + // and entry computation parameters. // // - owning_buffers contains owning pointers to the buffers that were // allocated by this routine. This routine allocates buffers for temporary @@ -111,22 +112,21 @@ class CpuExecutable : public Executable { // result. 
StatusOr, std::vector>> - CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments); + CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + absl::Span arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunction(const ExecutableRunOptions* run_options, + absl::Span buffers, + HloExecutionProfile* hlo_execution_profile); // Creates a ScopedShapedBuffer for holding the result of the computation, // moving buffers out of allocated_buffers and into the result as appropriate. // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers); + absl::Span buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7bd4741a04b1135d9780e0cf765b7b33378526e1..7fbe0fa157c57eb0c274662a1de95cf5328ccfa8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr CpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "CPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 6af724b2a5d71b9c30f3485ffb7e51d1d201cb6b..a39a9d4724655370454d60fbb7b474f223bd8a85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // This pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the CPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). 
-class CpuHloSupportChecker : public HloPassInterface { +class CpuHloSupportChecker : public HloModulePass { public: CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index 0f463e6de623fc6ab43d685ff2a5d6882ba7b8a2..e6b6fcdf684eadb3702e490bbe24dbb7b3b52ec7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloTestBase { +class CpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -57,7 +57,10 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { HloInstruction::CreateParameter(1, sparse_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( sparse_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + // Since verifier is reporting sparse layouts as errors, we should + // use a regular HloModule instead of VerifiedHloModule to avoid + // verifier 
errors being triggered in the destructor. + auto module = HloTestBase::CreateNewModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index b40d264c03aba6e9308e8a621ae86e180e33c335..f9cd61bea3dc86cadff99d4a90eca44c16520823 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -35,7 +35,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kSlice || @@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (!CanBeLoopFused(*producer)) { - VLOG(2) << "Producer is not fusile."; + VLOG(2) << "Producer is not fusible."; return false; } @@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (CanBeLoopFused(*consumer)) { - VLOG(2) << "Fusing: consumer is elementwise or fusile."; + VLOG(2) << "Fusing: consumer is elementwise or fusible."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index c3e03056f0f5526932de74efbd0433919d63aba1..7d99b914d4f5e5d27722bcd098d2ae0c54a36a23 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,11 +19,12 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace op = xla::testing::opcode_matchers; @@ -38,7 +39,11 @@ std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + precision_config); } TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { @@ -567,7 +572,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { HloOpcode::kParameter, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, MessOfFusileNodes) { +TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -692,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, auto* addend = builder.AddInstruction( HloInstruction::CreateParameter(2, dot_shape, "param2")); - auto* dot = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + auto* dot = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); builder.AddInstruction( HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 
bfecbd6e017893e4f6d3dcbc01d46c899e6060fa..c291bf2d1ba2eaff4192051840768c037bece86f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" @@ -38,7 +39,7 @@ using absl::nullopt; using absl::optional; using ShouldMakeOperandColMajorCache = - tensorflow::gtl::FlatMap; + absl::flat_hash_map; } // namespace static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 3c4fe68b830d9602f009b318d4e51e9a04a27e09..f4da35dd373f24d81323d198582048e2e6d36268 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, const TargetMachineFeatures* target_machine_features) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 3681d12d8da818d06d2f690024008c9ccb896286..97659b88a7974d7caf91ab0d4741f3635e4dae4a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ 
-20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase { [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -70,7 +71,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -107,9 +108,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, 
dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -151,9 +152,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -189,7 +190,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -229,7 +230,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -276,8 +277,8 @@ static StatusOr RunDotOutputFusion( HloInstruction::CreateParameter(1, dot_shape, "param1")); HloInstruction* dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); - HloInstruction* dot_result = 
builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* dot_result = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); HloInstruction* add_result; if (dot_operand_idx_in_add == 0) { add_result = builder.AddInstruction(HloInstruction::CreateBinary( @@ -321,8 +322,9 @@ static StatusOr RunDotOutputFusion( [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 639064040f521a9e84bd87c5d05f674204e4d6e2..a9febe891b5e9d1eb9e6b297952b50d1d26a3396 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -17,18 +17,29 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace cpu { namespace runtime { -XfeedManager* GetXfeedManager() { - static XfeedManager* manager = new XfeedManager; - return manager; +XfeedManager* GetXfeedManager(int device_ordinal) { + static auto* managers = new absl::flat_hash_map(); + static absl::Mutex* mutex = new absl::Mutex(); + + absl::MutexLock lock(mutex); + auto it = managers->find(device_ordinal); + if (it == managers->end()) { + it = managers->emplace(device_ordinal, new XfeedManager()).first; + } + return it->second; } extern const char* const kEigenMatMulF16SymbolName = @@ -73,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; +extern const char* const kKeyValueSortPREDSymbolName = + "__xla_cpu_runtime_KeyValueSortPRED"; +extern const char* const kKeyValueSortS8SymbolName = + "__xla_cpu_runtime_KeyValueSortS8"; +extern const char* const kKeyValueSortU8SymbolName = + "__xla_cpu_runtime_KeyValueSortU8"; +extern const char* const kKeyValueSortS16SymbolName = + "__xla_cpu_runtime_KeyValueSortS16"; +extern const char* const kKeyValueSortU16SymbolName = + "__xla_cpu_runtime_KeyValueSortU16"; +extern const char* const kKeyValueSortF16SymbolName = + "__xla_cpu_runtime_KeyValueSortF16"; +extern const char* const kKeyValueSortS32SymbolName = + "__xla_cpu_runtime_KeyValueSortS32"; +extern const char* const kKeyValueSortU32SymbolName = + "__xla_cpu_runtime_KeyValueSortU32"; +extern const char* const 
kKeyValueSortF32SymbolName = + "__xla_cpu_runtime_KeyValueSortF32"; +extern const char* const kKeyValueSortS64SymbolName = + "__xla_cpu_runtime_KeyValueSortS64"; +extern const char* const kKeyValueSortU64SymbolName = + "__xla_cpu_runtime_KeyValueSortU64"; +extern const char* const kKeyValueSortF64SymbolName = + "__xla_cpu_runtime_KeyValueSortF64"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime @@ -93,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireInfeedBufferForDequeue: " - << ShapeString(shape, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireInfeedBufferForDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireInfeedBufferForDequeue: " + << ShapeString(shape, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. 
xla::cpu::runtime::XfeedBuffer* buffer = xfeed->infeed()->BlockingDequeueBuffer(); @@ -113,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseInfeedBufferAfterDeque: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, @@ -129,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireOutfeedBufferForPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? 
run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireOutfeedBufferForPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. xla::cpu::runtime::XfeedBuffer* buffer = xfeed->outfeed()->BlockingDequeueBuffer(); @@ -149,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index aa0e96712302e806a389c6ad05a2c1b6634ef901..b2e760a224ad8eaa61dae57b0f9cece04a7e54ae 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -26,6 +26,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ +#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/types.h" @@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; +extern const char* const kKeyValueSortPREDSymbolName; +extern const char* const kKeyValueSortS8SymbolName; +extern const char* const kKeyValueSortU8SymbolName; +extern const char* const kKeyValueSortS16SymbolName; +extern const char* const kKeyValueSortU16SymbolName; +extern const char* const kKeyValueSortF16SymbolName; +extern const char* const kKeyValueSortS32SymbolName; +extern const char* const kKeyValueSortU32SymbolName; +extern const char* const kKeyValueSortF32SymbolName; +extern const char* const kKeyValueSortS64SymbolName; +extern const char* const kKeyValueSortU64SymbolName; +extern const char* const kKeyValueSortF64SymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. extern const char* const kXlaCpuRuntimeSymbolNamePrefix; -// Returns the infeed manager used by the CPU runtime. -XfeedManager* GetXfeedManager(); +// Returns the infeed manager used by the CPU runtime for the CPU device +// `device_ordinal`. Note the device ordinal does not name a CPU +XfeedManager* GetXfeedManager(int device_ordinal); } // namespace runtime } // namespace cpu @@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager(); extern "C" { +// Some things common to all of the runtime entry points below: +// +// * The shape pointer and shape_length reflect values that can be deserialized +// via llvm_ir::DecodeSelfDescribingShapeConstant. 
This is the way we pass +// reified type information from the generated program to the runtime, which +// helps check the type safety and contract for the emitted-code/runtime +// communication. +// +// * run_options is used to look up the device ordinal for the stream executor +// we're executing under. If it is null the device ordinal is assumed to be +// 0 (this behavior helps in writing tests). + // Note: in the runtime entry points below, the shape pointer and shape_length // reflect values that can be deserialized via // llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified @@ -89,7 +115,8 @@ extern "C" { // the length would be more exact, but the length check is chosen as a // tradeoff between error checking and speed/simplicity. extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - xla::int32 buffer_length, const void* shape, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length); // Relinquishes the next infeed buffer that was returned by // __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call @@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( // implemented we will add support for multiple outstanding buffers // that can be returned out of order. extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); // Blocks until the next outfeed buffer is available to be populated, then // returns it. 
extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length); // Relinquishes the outfeed buffer after it has been populated. // buffer_ptr must have been previously returned by @@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( // acquired, i.e., there may only be one outstanding outfeed buffer in // use by the runtime. extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); } // extern "C" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index bc4cfc099965e2ab12212f55e62bdf79c0cfb739..1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b07cd675ffc4dbd0c7d56da715b29014bb12ce88..1cc2844470376ceb61601f6d1361def84eac5b45 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { @@ -104,7 +105,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( if (ShapeUtil::IsNestedTuple(shape)) { return Unimplemented( "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + ShapeUtil::HumanString(literal.shape())); } // For a tuple, we transfer each of its elements to the device and @@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed( buffers.push_back(buffer); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers); cleanup.release(); @@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, size, source)); - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer}); return Status::OK(); @@ -152,11 +155,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Infeed shape must have positive size; got %lld", + return InvalidArgument("Infeed shape must have positive size; got %d", size); } @@ -179,7 +182,7 @@ Status 
CpuTransferManager::TransferLiteralFromOutfeed( int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( + absl::Span dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); TF_ASSIGN_OR_RETURN( @@ -225,7 +228,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( StatusOr CpuTransferManager::TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data) { + absl::Span> buffer_data) { return TransferBuffersFromOutfeedInternal(executor, buffer_data, /*is_tuple=*/true); } @@ -238,18 +241,17 @@ StatusOr CpuTransferManager::TransferArrayBufferFromOutfeed( StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple) { + absl::Span> buffer_data, bool is_tuple) { std::vector> buffers; for (auto b : buffer_data) { int64 size = b.second; if (size > std::numeric_limits::max()) { - return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + return InvalidArgument("Outfeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %lld", + return InvalidArgument("Outfeed shape must have positive size; got %d", size); } @@ -266,7 +268,8 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( buffer_pointers.push_back(b.get()); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers); VLOG(2) << "Waiting for buffer to be notified as populated."; std::vector outfed_shapes; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h 
b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 7b938e9fd7d59109c7ffec4fc67c1d2ee50ea65f..361d4b9c8422fff6afe53e56e0bb10a484c9becc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -18,13 +18,13 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager { // Helper that transfers a tuple of element buffers from the device's outfeed. StatusOr TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data); + absl::Span> buffer_data); // Helper that transfers an array buffer from the device's outfeed. StatusOr TransferArrayBufferFromOutfeed(se::StreamExecutor* executor, @@ -68,8 +68,7 @@ class CpuTransferManager : public GenericTransferManager { // for the given buffers. 
StatusOr TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple); + absl::Span> buffer_data, bool is_tuple); TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index e4c674e227ffc6725ca929f720b9aa7cf7c4c032..3ae64142cd7e32d3aa8d50870efaf94698c06440 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "llvm/MC/MCInst.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -151,7 +151,7 @@ StatusOr Disassembler::DisassembleObjectFile( size = 1; } - ostream << tensorflow::strings::Printf("0x%08lx", index) << " "; + ostream << absl::StrFormat("0x%08lx", index) << " "; if (decode_status == llvm::MCDisassembler::Success) { // For branches, try to determine the actual address and emit it as an @@ -163,7 +163,7 @@ StatusOr Disassembler::DisassembleObjectFile( uint64_t target; if (inst_analysis_->evaluateBranch( instruction, section_address + index, size, target)) { - annotation = tensorflow::strings::Printf("[0x%08lx]", target); + annotation = absl::StrFormat("[0x%08lx]", target); } } inst_printer_->printInst(&instruction, ostream, annotation.c_str(), diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 4af16f4fa0817df8a117b7852a8e5a2ef611e1c9..99fa707c959854e50c6d954fe92b87e93e267dc6 100644 --- 
a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -80,7 +80,7 @@ class MemoryTile { // `minor_dim_offset`}. // // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(tensorflow::gtl::ArraySlice tile, + void StoreTile(absl::Span tile, llvm::Value* minor_dim_offset) const { CHECK_EQ(tile.size(), pointers_.size()); for (int64 i = 0; i < pointers_.size(); i++) { @@ -1467,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() { break; default: return Unimplemented("Invalid type %s for dot operation", - PrimitiveType_Name(type).c_str()); + PrimitiveType_Name(type)); } llvm::Type* float_ptr_type = float_type->getPointerTo(); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index db54454707983ade31594119b2e868fa168d4cc2..c8312d80bd5012e5bcb42a410db18a7fa77a2eb6 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -30,15 +30,16 @@ limitations under the License. namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { string function_name; bool cast_result_to_fp16 = false; switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); - rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); + lhs = FPCast(lhs, b_->getFloatTy()); + rhs = FPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -58,21 +59,21 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. 
- llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = Call(function, {lhs, rhs}); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } -StatusOr CpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { bool cast_result_to_fp16 = false; string function_name; switch (prim_type) { case F16: cast_result_to_fp16 = true; - value = b_->CreateFPCast(value, b_->getFloatTy()); + value = FPCast(value, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; @@ -91,16 +92,16 @@ StatusOr CpuElementalIrEmitter::EmitTanh( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, value); + llvm::Value* result = Call(function, value); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { if (hlo->opcode() == HloOpcode::kMap) { return [this, hlo, &operand_to_generator]( const llvm_ir::IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 76833e765d05f2477961cd06cead66797c5be623..e3fba9306b72904803259047fafea245a8e183db 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const 
HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 417a1dba1f8593ac5d234838b9aba7879899e02e..b2abdb39a598871a7cc44760e464f48b9a200874 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,10 +24,14 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" @@ -65,10 +69,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -100,12 +100,17 @@ IrEmitter::IrEmitter( b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_cpu_enable_fast_math())); + Status s = GatherComputationsByAllocationType( + &hlo_module, &thread_local_computations_, &global_computations_); + absl::c_sort(thread_local_computations_); + absl::c_sort(global_computations_); + TF_CHECK_OK(s) << "Should have failed buffer assignment."; } StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); @@ -170,9 +175,9 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + BitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } @@ -230,9 +235,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // Use the elemental emitter for array shapes. 
return DefaultAction(copy); } - return Unimplemented( - "unsupported operand type %s for copy instruction", - PrimitiveType_Name(copy->shape().element_type()).c_str()); + return Unimplemented("unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type())); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -338,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Write the tuple index table. TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, assignment_.GetUniqueSlice(infeed, {0})); - llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, assignment_.GetUniqueSlice(infeed, {1})); - llvm::Value* token_address = EmitTempBufferPointer( + llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); @@ -364,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root // instruction. Target addresses for internal elements can be obtained - // from EmitTempBufferPointer. + // from EmitBufferPointer. 
llvm::Value* tuple_element_address = - EmitTempBufferPointer(buffer, tuple_element_shape); + EmitBufferPointer(buffer, tuple_element_shape); TF_RETURN_IF_ERROR(EmitXfeedTransfer( XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); @@ -389,7 +393,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int64 length = ByteSizeOf(shape); if (length <= 0 || length > std::numeric_limits::max()) { return InvalidArgument( - "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", length); } @@ -400,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value * shape_ptr, llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); - // The signature of the acquire infeed buffer function is: - // - // (void*)(int32 length); llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* acquire_type = llvm::FunctionType::get( - i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, + i8_ptr_type, + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* acquire_func; @@ -419,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, } acquire_func->setCallingConv(llvm::CallingConv::C); - // The signature of the release infeed buffer function is: - // - // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, + b_.getVoidTy(), + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type, + /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* release_func; @@ -440,27 +443,33 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, 
const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = b_.CreateCall( - acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = Call( + acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, - /*SrcAlign=*/1, length_32); + MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, - /*SrcAlign=*/1, length_32); + MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, - shape_ptr, b_.getInt32(shape_length)}); + Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + acquired_pointer, shape_ptr, b_.getInt32(shape_length)}); return Status::OK(); } Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { + // Outfeed produces no useful result, but it does return a token[] that can be + // threaded through to other side effecting operations to ensure ordering. In + // the IR emitter we treat this token as a normal u8[] and thus need to insert + // an entry for it in emitted_value_. 
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed)); + HloInstruction* operand = outfeed->operands()[0]; const Shape& operand_shape = operand->shape(); @@ -485,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { } Status IrEmitter::HandleSort(HloInstruction* sort) { - // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not implemented on CPU."); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto keys_destination_address = + EmitBufferPointer(keys_destination, keys->shape()); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + llvm::Value* values_destination_address = nullptr; + + // The sort is implemented in-place, therefore we first copy the operand + // buffer to the output buffer if they are not the same. 
+ if (keys_destination != GetAllocationSlice(*keys)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type()); + auto source_buffer = GetEmittedValueFor(keys); + int64 keys_size = ByteSizeOf(keys->shape()); + MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, keys_size); + } + if (values != nullptr) { + values_destination_address = + EmitBufferPointer(values_destination, values->shape()); + if (values_destination != GetAllocationSlice(*values)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type()); + auto source_buffer = GetEmittedValueFor(values); + int64 values_size = ByteSizeOf(values->shape()); + MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, values_size); + } + } + + // Normalize the shape and the dimension to sort. + Shape normalized_keys_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + keys->shape()); + int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical( + keys->shape().layout())[sort->dimensions(0)]; + + int64 sort_dimension_elements = + normalized_keys_shape.dimensions(physical_dimension_to_sort); + int64 higher_dimensions = 1; + for (int64 i = 0; i < physical_dimension_to_sort; ++i) { + higher_dimensions *= normalized_keys_shape.dimensions(i); + } + int64 lower_dimensions = 1; + for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + i > physical_dimension_to_sort; --i) { + lower_dimensions *= normalized_keys_shape.dimensions(i); + } + + PrimitiveType keys_type = keys->shape().element_type(); + const char* fn_name = nullptr; + llvm::Type* keys_native_type = nullptr; + switch (keys_type) { + case PRED: + fn_name = runtime::kKeyValueSortPREDSymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S8: + fn_name = runtime::kKeyValueSortS8SymbolName; + keys_native_type = 
b_.getInt8PtrTy(); + break; + case U8: + fn_name = runtime::kKeyValueSortU8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S16: + fn_name = runtime::kKeyValueSortS16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case U16: + fn_name = runtime::kKeyValueSortU16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case F16: + fn_name = runtime::kKeyValueSortF16SymbolName; + keys_native_type = b_.getHalfTy()->getPointerTo(); + break; + case S32: + fn_name = runtime::kKeyValueSortS32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case U32: + fn_name = runtime::kKeyValueSortU32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case F32: + fn_name = runtime::kKeyValueSortF32SymbolName; + keys_native_type = b_.getFloatTy()->getPointerTo(); + break; + case S64: + fn_name = runtime::kKeyValueSortS64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case U64: + fn_name = runtime::kKeyValueSortU64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case F64: + fn_name = runtime::kKeyValueSortF64SymbolName; + keys_native_type = b_.getDoubleTy()->getPointerTo(); + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } + + llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( + b_.getVoidTy(), + {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + b_.getInt8PtrTy(), b_.getInt32Ty()}, + /*isVarArg=*/false); + auto* key_value_sort_func = llvm::cast( + module_->getOrInsertFunction(fn_name, key_value_sort_type)); + key_value_sort_func->setCallingConv(llvm::CallingConv::C); + key_value_sort_func->setDoesNotThrow(); + key_value_sort_func->setOnlyAccessesArgMemory(); + Call(key_value_sort_func, + {PointerCast(keys_destination_address, keys_native_type), + b_.getInt64(higher_dimensions), 
b_.getInt64(sort_dimension_elements), + b_.getInt64(lower_dimensions), + values != nullptr + ? PointerCast(values_destination_address, b_.getInt8PtrTy()) + : llvm::Constant::getNullValue(b_.getInt8PtrTy()), + b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType( + values->shape().element_type()) + : 0)}); + + if (values != nullptr) { + llvm_ir::EmitTuple(GetIrArrayFor(sort), + {keys_destination_address, values_destination_address}, + &b_, module_); + } + return Status::OK(); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -501,8 +652,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, - absl::string_view name) { + absl::Span elemental_operands, absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -519,8 +669,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); - b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); + Store(Load(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); std::vector window_size; @@ -537,22 +687,38 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = - b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(strided_index, + NSWMul(window_index[i], + b_.getInt64(window.dimensions(i).window_dilation()))), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to verify that we 
are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); + if (in_bounds_condition == nullptr) { + in_bounds_condition = dilation_condition; + } else { + in_bounds_condition = And(in_bounds_condition, dilation_condition); + } + + // Apply base dilation to the index. input_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to // input_index[i] < bound as an *unsigned* comparison, since a negative // value will wrap to a large positive value. - llvm::Value* index_condition = b_.CreateICmpULT( - input_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + llvm::Value* index_condition = + ICmpULT(input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; } else { - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } } CHECK(in_bounds_condition != nullptr); @@ -565,12 +731,12 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); - b_.CreateStore(result, accumulator_address); + *reduce_window->to_apply(), {Load(accumulator_address), input_value}, + "reducer_function"); + Store(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - 
return b_.CreateLoad(accumulator_address); + return Load(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { @@ -579,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32, F16})); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(reduce_window->window())) { - return Unimplemented( - "Dilation for ReduceWindow is not implemented on CPU."); - } - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -647,7 +807,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return b_.CreateLoad(init_value_addr); + return Load(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. @@ -667,7 +827,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. 
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); @@ -685,15 +845,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( - source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( - operand_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + llvm::Value* strided_index = + NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = + ICmpULT(operand_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -703,7 +862,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. 
@@ -712,38 +871,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* operand_element = Load(operand_address); llvm::Value* result = EmitThreadLocalCall( *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, - "select_function"); + {Load(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. 
- llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -754,8 +912,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = @@ -837,7 +995,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( lhs_llvm_type, "convolution_sum_address", &b_, MinimumAlignmentForPrimitiveType(lhs_element_type)); llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); - b_.CreateStore(constant_zero, sum_address); + Store(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); std::vector kernel_spatial(num_spatial_dims); @@ -864,11 +1022,11 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::Value* kernel_index, const WindowDimension& window_dim) { llvm::Value* strided_index = - b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = b_.CreateNSWMul( - kernel_index, b_.getInt64(window_dim.window_dilation())); - return 
b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), - b_.getInt64(window_dim.padding_low())); + NSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = + NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation())); + return NSWSub(NSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -885,9 +1043,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( // Also need to check that the input coordinates are not in one of the // holes created by base dilation. const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { - llvm::Value* remainder = - b_.CreateSRem(input_index, b_.getInt64(base_dilation)); - return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation)); + return ICmpEQ(remainder, b_.getInt64(0)); }; llvm::Value* in_bounds_condition = b_.getInt1(true); @@ -895,17 +1052,17 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound); llvm::Value* dim_not_in_hole = not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole); + in_bounds_condition = And(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual // data indices on the lhs. 
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { - return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + return SDiv(input_index, b_.getInt64(base_dilation)); }; for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = @@ -930,8 +1087,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() - ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) + ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) : kernel_spatial[i]; } @@ -940,13 +1097,13 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = - b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), - kernel_array.EmitReadArrayElement(kernel_index, &b_)); - llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); - b_.CreateStore(sum, sum_address); + FMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = FAdd(Load(sum_address), product); + Store(sum, sum_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(sum_address); + return Load(sum_address); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1072,34 +1229,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - b_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), - b_.CreateBitCast(lhs_address, ir_ptr_type), - b_.CreateBitCast(rhs_address, ir_ptr_type), - b_.getInt64(input_batch), - b_.getInt64(input_rows), - b_.getInt64(input_cols), - 
b_.getInt64(input_channels), - b_.getInt64(kernel_rows), - b_.getInt64(kernel_cols), - b_.getInt64(kernel_channels), - b_.getInt64(kernel_filters), - b_.getInt64(output_rows), - b_.getInt64(output_cols), - b_.getInt64(row_stride), - b_.getInt64(col_stride), - b_.getInt64(padding_top), - b_.getInt64(padding_bottom), - b_.getInt64(padding_left), - b_.getInt64(padding_right), - b_.getInt64(lhs_row_dilation), - b_.getInt64(lhs_col_dilation), - b_.getInt64(rhs_row_dilation), - b_.getInt64(rhs_col_dilation), - }); + Call(conv_func, { + GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(convolution), ir_ptr_type), + BitCast(lhs_address, ir_ptr_type), + BitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1159,15 +1314,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - b_.CreateCall( - fft_func, - {GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - b_.CreateBitCast(operand_address, int8_ptr_type), - b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), - b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - b_.getInt64(fft_rank > 2 ? 
fft_length[2] : 0)}); + Call(fft_func, + {GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(fft), int8_ptr_type), + BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(fft_rank), b_.getInt64(input_batch), + b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); return Status::OK(); } @@ -1203,11 +1357,11 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { const Shape& operand_shape = crs->operand(i)->shape(); CHECK(ShapeUtil::IsArray(operand_shape)) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); + MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); @@ -1255,10 +1409,10 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { // // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains // [0->0, 3->1]. 
- gtl::FlatMap unreduced_dim_map; + absl::flat_hash_map unreduced_dim_map; - gtl::FlatSet reduced_dims(reduce.dimensions().begin(), - reduce.dimensions().end()); + absl::flat_hash_set reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); const Shape& operand_shape = reduce.operand(0)->shape(); const Shape& result_shape = reduce.shape(); @@ -1457,7 +1611,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment) { ShardedVector accumulator; accumulator.reserve(accumulator_type.size()); @@ -1466,19 +1620,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast(shard_type)) { initial_value = - b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); + VectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment); + AlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), @@ -1500,24 +1654,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = b_.CreateBitCast( + llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { 
auto input_address_typed = - b_.CreateBitCast(input_address, accumulator[i]->getType()); + BitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - b_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); + AlignedLoad(accumulator[i], element_alignment); + auto addend = AlignedLoad(input_address_typed, element_alignment); arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = reduction_generator(&b_, current_accumulator_value, addend); - b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); + AlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), - input_address_typed, 1); + input_address = ConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } @@ -1526,8 +1680,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { - result_ssa.push_back( - b_.CreateAlignedLoad(accumulator_shard, element_alignment)); + result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1536,25 +1689,25 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for (int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = b_.CreateBitCast( - store_address, - llvm::PointerType::getUnqual(value_to_store[i]->getType())); + auto store_address_typed = + BitCast(store_address, + llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = b_.CreateAlignedStore( - value_to_store[i], store_address_typed, alignment); + auto store_instruction = + AlignedStore(value_to_store[i], store_address_typed, 
alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = b_.CreateConstInBoundsGEP1_32( - value_to_store[i]->getType(), store_address_typed, 1); + store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(), + store_address_typed, 1); } } } StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - gtl::ArraySlice dimensions, HloComputation* function, + absl::Span dimensions, HloComputation* function, string* failure_reason) { if (!ReductionPreservesLayout(*reduce)) { return false; @@ -1620,9 +1773,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); - std::unique_ptr loop = - loop_nest.AddLoop(start_index, end_index, - tensorflow::strings::Printf("dim.%lld", dimension)); + std::unique_ptr loop = loop_nest.AddLoop( + start_index, end_index, absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -1641,9 +1793,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 start_index = 0; int64 end_index = (innermost_dimension_size / vectorization_factor) * vectorization_factor; - std::unique_ptr loop = loop_nest.AddLoop( - start_index, end_index, vectorization_factor, - tensorflow::strings::Printf("dim.%lld", innermost_dimension)); + std::unique_ptr loop = + loop_nest.AddLoop(start_index, end_index, vectorization_factor, + absl::StrFormat("dim.%d", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); @@ -1705,7 +1857,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); - 
gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1713,8 +1865,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); - b_.CreateStore(load_init_value, accumulator_addr); + llvm::Value* load_init_value = Load(init_value_addr); + Store(load_init_value, accumulator_addr); // The enclosing loops go over all the target elements. Now we have to compute // the actual target element. For this, we build a new loop nest to iterate @@ -1747,12 +1899,12 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( // Apply the reduction function to the loaded value. llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + *reduce->to_apply(), {Load(accumulator_addr), input_element}, "reduce_function"); - b_.CreateStore(result, accumulator_addr); + Store(result, accumulator_addr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); } Status IrEmitter::HandleReduce(HloInstruction* reduce) { @@ -1762,7 +1914,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; @@ -1836,7 +1988,7 @@ Status 
IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - gtl::FlatSet inner_dims; + absl::flat_hash_set inner_dims; for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; @@ -1990,7 +2142,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return b_.CreateLoad(padding_value_addr); + return Load(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output @@ -2012,10 +2164,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = b_.CreateMul( - operand_index[i], - b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); - llvm::Value* index = b_.CreateAdd( + llvm::Value* offset = + Mul(operand_index[i], + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = Add( offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } @@ -2102,7 +2254,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument()); HloInstruction* root = computation->root_instruction(); @@ -2117,7 +2269,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { - gtl::ArraySlice operands(custom_call->operands()); + 
absl::Span operands(custom_call->operands()); absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -2126,10 +2278,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + PointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = - b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); - b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + InBoundsGEP(operands_alloca, {b_.getInt64(i)}); + Store(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast(module_->getOrInsertFunction( @@ -2141,9 +2293,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); auto* output_address_arg = - b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); + PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); + Call(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2170,8 +2322,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { return InternalError( "instruction %s %s does not share slice with " "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), + slice_b.ToString()); } return Status::OK(); }; @@ -2202,15 +2354,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - b_.CreateBr(header_bb); + 
Br(header_bb); b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); - llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + llvm::Value* while_predicate = ICmpNE( + Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2219,7 +2370,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - b_.CreateCondBr(while_predicate, body_bb, exit_bb); + CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. b_.SetInsertPoint(body_bb); @@ -2228,7 +2379,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); // Finishes with a branch back to the header. - b_.CreateBr(header_bb); + Br(header_bb); // Adds the exit block to the function and sets the insert point there. 
compute_function_->function()->getBasicBlockList().push_back(exit_bb); @@ -2238,7 +2389,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { } StatusOr IrEmitter::EmitFastConcatenate( - HloInstruction* concatenate, gtl::ArraySlice operands, + HloInstruction* concatenate, absl::Span operands, string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { *failure_reason = @@ -2275,7 +2426,6 @@ StatusOr IrEmitter::EmitFastConcatenate( output_min2maj.end()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); - llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); @@ -2298,9 +2448,9 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = - b_.CreateBitCast(target_array.EmitArrayElementAddress( - outer_dims_index, &b_, "target_region"), - i8_ptr_type); + BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, + "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2314,13 +2464,12 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = b_.CreateBitCast( + llvm::Value* copy_source_address = BitCast( source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = - b_.CreateGEP(i8_type, target_region_begin, - b_.getInt64(byte_offset_into_target_region)); + GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2352,15 +2501,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, 
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = b_.CreateAlignedLoad( - b_.CreateBitCast(source, primitive_ptr_type), element_alignment); + auto* load_instruction = + AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = b_.CreateAlignedStore( - load_instruction, b_.CreateBitCast(target, primitive_ptr_type), - element_alignment); + auto* store_instruction = + AlignedStore(load_instruction, BitCast(target, primitive_ptr_type), + element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = b_.CreateMemCpy( + auto* memcpy_instruction = MemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * primitive_type_size); @@ -2376,7 +2525,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { - gtl::ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2422,9 +2571,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = b_.CreateLoad( - GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = b_.CreateICmpNE( + llvm::LoadInst* pred_value = + Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); @@ -2450,11 +2599,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } -Status IrEmitter::HandleIota(HloInstruction* 
iota) { - // TODO(b/64798317): implement iota on CPU. - return Unimplemented("Iota is not implemented on CPU."); -} - Status IrEmitter::HandleRng(HloInstruction* rng) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { @@ -2511,8 +2655,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return b_.CreateGEP(GetProfileCountersArgument(), - b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); + return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), + AsStringRef(counter_name)); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2630,15 +2774,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() { return compute_function_->profile_counters_arg(); } -llvm::Value* IrEmitter::GetTempBuffersArgument() { - return compute_function_->temp_buffers_arg(); +llvm::Value* IrEmitter::GetBufferTableArgument() { + return compute_function_->buffer_table_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address = [&]() -> llvm::Value* { @@ -2666,8 +2810,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); + llvm::LoadInst* param_address_untyped = Load(param_address_offset); if (!ShapeUtil::IsOpaque(target_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); @@ -2695,16 
+2838,15 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( } return buf_it->second; }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( +llvm::Value* IrEmitter::EmitGlobalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &b_); - llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + GetBufferTableArgument(), slice.index(), &b_); + llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { tempbuf_address_base->setMetadata( @@ -2718,20 +2860,20 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( if (slice.offset() > 0) { // Adjust the address to account for the slice offset. 
tempbuf_address_untyped = - b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); + InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { +llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape) { if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); + return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); } else { - return EmitGlobalTempBufferPointer(slice, target_shape); + return EmitGlobalBufferPointer(slice, target_shape); } } @@ -2739,7 +2881,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { const Shape& target_shape = op->shape(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + llvm::Value* addr = EmitBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2768,8 +2910,7 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, assignment_.GetUniqueSlice(target_op, {i})); const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); - llvm::Value* op_target_address = - EmitTempBufferPointer(slice, element_shape); + llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape); output_arrays.push_back( llvm_ir::IrArray(op_target_address, element_shape)); } @@ -2807,15 +2948,15 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& 
source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + MemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, - gtl::ArraySlice operands, - gtl::ArraySlice supported_types) { + absl::Span operands, + absl::Span supported_types) { for (auto operand : operands) { TF_RET_CHECK( ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); @@ -2826,8 +2967,8 @@ Status IrEmitter::ElementTypesSameAndSupported( if (std::find(supported_types.begin(), supported_types.end(), primitive_type) == supported_types.end()) { return Unimplemented("unsupported operand type %s in op %s", - PrimitiveType_Name(primitive_type).c_str(), - HloOpcodeString(instruction.opcode()).c_str()); + PrimitiveType_Name(primitive_type), + HloOpcodeString(instruction.opcode())); } return Status::OK(); } @@ -2845,9 +2986,10 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } llvm::Value* IrEmitter::EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, + const HloComputation& callee, absl::Span parameters, absl::string_view name) { + CHECK(absl::c_binary_search(thread_local_computations_, &callee)); + const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to allow "small" arrays should be easy. 
Allowing @@ -2862,7 +3004,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( CHECK(!parameter->getType()->isPointerTy()); llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); + Store(parameter, parameter_addr); parameter_addrs.push_back(parameter_addr); } @@ -2871,29 +3013,30 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( absl::StrCat(name, "_retval_addr"), &b_, MinimumAlignmentForPrimitiveType(return_type)); - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*buffer_table_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); - return b_.CreateLoad(return_value_buffer); + return Load(return_value_buffer); } void IrEmitter::EmitGlobalCall(const HloComputation& callee, absl::string_view name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + CHECK(absl::c_binary_search(global_computations_, &callee)); + + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + 
/*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( @@ -2905,7 +3048,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( const BufferAllocation::Slice root_buffer = assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); + return EmitBufferPointer(root_buffer, root_inst->shape()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 99c080b3dbbaf528d938385210eacd8d59163557..586f27b104ed706a3b128903c6a90abbf3667e59 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,7 +23,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -40,13 +42,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -55,13 +56,14 @@ namespace cpu { // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: // Create a new LLVM IR emitter. // // hlo_module: the HLO module we are emitting IR for. - // assignment: a BufferAssignment from which we know which temporary buffers - // are used by the HLO nodes. + // assignment: a BufferAssignment from which we know which buffers are used by + // the HLO nodes. // llvm_module: the LLVM module to emit IR into. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. @@ -96,17 +98,20 @@ class IrEmitter : public DfsHloVisitorWithDefault { StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return &b_; } + // Emit an LLVM global variable for every constant buffer allocation. 
Status EmitConstantGlobals(); // Emit code to map one element according to `map_instr`. llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, + absl::Span elemental_operands, absl::string_view name); protected: @@ -152,13 +157,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; - Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { + return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie(); + } + private: // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const string& function_name); @@ -215,24 +225,21 @@ class IrEmitter : public DfsHloVisitorWithDefault { // argument of the computation function being emitted by this emitter. llvm::Value* GetExecutableRunOptionsArgument(); - // Get the llvm::Value* that represents the "temps" argument of the + // Get the llvm::Value* that represents the "buffer_table" argument of the // computation function being emitted by this emitter. - llvm::Value* GetTempBuffersArgument(); + llvm::Value* GetBufferTableArgument(); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + // Helper for EmitBufferPointer. + llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); - // Helper for EmitTempBufferPointer. 
- llvm::Value* EmitThreadLocalTempBufferPointer( + // Helper for EmitBufferPointer. + llvm::Value* EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape); // Emits code that computes the address of the given buffer allocation slice. - // - // TODO(sanjoy): This should be renamed to reflect that it no longer provides - // access to just temporaries. - llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); // Emits a function into the current module. This can be used for // computations embedded inside other computations, such as the @@ -248,10 +255,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // // `parameters` holds the *scalar values* that need to be passed to the // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, - absl::string_view name); + llvm::Value* EmitThreadLocalCall(const HloComputation& callee, + absl::Span parameters, + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to @@ -267,8 +273,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // match and are of one of the given supported types. Status ElementTypesSameAndSupported( const HloInstruction& instruction, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice supported_types); + absl::Span operands, + absl::Span supported_types); // Emit IR to perform a computation for every element in the given target op. // This produces a series of nested loops (one for each dimension of the op's @@ -315,10 +321,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { // concepts that generalize over other vectorizable operations. 
We should // consider pulling out these abstractions into a VectorizingIrEmitter or // something similar. - StatusOr EmitVectorizedReduce( - HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function, - string* failure_reason); + StatusOr EmitVectorizedReduce(HloInstruction* reduce, + HloInstruction* arg, + HloInstruction* init_value, + absl::Span dimensions, + HloComputation* function, + string* failure_reason); // We'd like to keep one or two one cache-line's worth of data in registers // without generating IR with illegal (e.g. excessively large or @@ -368,16 +376,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment); // Tries to emit a fast concatenate operation using memcpy. Returns true if // successful, and false on failure. On failure, sets "failure_reason" to a // string describing why it could not emit a fast concatenate. - StatusOr EmitFastConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands, - string* failure_reason); + StatusOr EmitFastConcatenate(HloInstruction* concatenate, + absl::Span operands, + string* failure_reason); // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". @@ -386,8 +393,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Assignment of the temporary buffers needed by the computation and their - // shape information. + // Assignment of the buffers needed by the computation and their shape + // information. 
const BufferAssignment& assignment_; // The LLVM module into which IR will be emitted. @@ -420,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Maps the buffer allocation slices for the parameters to the computation // being compiled to their parameter numbers. Only relevant for thread local // computations. - tensorflow::gtl::FlatMap + absl::flat_hash_map computation_parameter_allocations_; // Maps HLO instructions to their index into the profile counter array. @@ -560,13 +567,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { } }; - tensorflow::gtl::FlatMap + absl::flat_hash_map emitted_literals_; - tensorflow::gtl::FlatMap + absl::flat_hash_map constant_buffer_to_global_; + std::vector thread_local_computations_; + std::vector global_computations_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 784045313dfa2d44da64c6b50be80258c5e8466a..adfb8392bf6fa356f0a5cdab3ff74036eca8918e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -78,19 +78,20 @@ void IrFunction::Initialize(const string& function_name, const bool optimize_for_size_requested, const bool enable_fast_math) { // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // void function(i8* retval, i8* run_options, i8** params, i8** + // buffer_table, // i64* dynamic_loop_bounds, i64* prof_counters) // // For thread local functions: // retval: points to the returned value. // params: address of an array with pointers to parameters. - // temps: is null + // buffer_table: is null // // For global functions: // retval: is null // params: is null - // temps: address of an array with pointers to temporary buffers and entry - // computation parameters. 
+ // buffer_table: address of an array with pointers to temporary buffers and + // entry computation parameters (but not to constant buffers). // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -116,7 +117,7 @@ void IrFunction::Initialize(const string& function_name, // \---------/ \---------/ \-----------/ // // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // buffer_table---> | buff 0 | guff 1 | ..... | buff N-1 | // | addr | addr | | addr | // \---------------------------------------------/ // | | | @@ -134,9 +135,9 @@ void IrFunction::Initialize(const string& function_name, // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | // \---------------------------------------------/ - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. + // Even though the type of params and buffer_table is void** in the host's + // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to + // the code to use GEPs to unravel the indirection layers. 
llvm::FunctionType* function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), /*Params=*/ @@ -160,8 +161,8 @@ void IrFunction::Initialize(const string& function_name, exec_run_options_arg_ = &*arg_iter; (++arg_iter)->setName("params"); parameters_arg_ = &*arg_iter; - (++arg_iter)->setName("temps"); - temp_buffers_arg_ = &*arg_iter; + (++arg_iter)->setName("buffer_table"); + buffer_table_arg_ = &*arg_iter; if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); dynamic_loop_bounds_arg_ = &*arg_iter; @@ -200,10 +201,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // Returns an array of compute function call arguments (including parameter // address buffer). std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, absl::string_view name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, + llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; if (parameter_addresses.empty()) { @@ -230,7 +231,7 @@ std::vector GetArrayFunctionCallArguments( }; std::vector arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), - parameter_addresses_buffer, temp_buffers_arg}; + parameter_addresses_buffer, buffer_table_arg}; if (profile_counters_arg != nullptr) { arguments.push_back(profile_counters_arg); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index ee7595f6e9706902a3e6b4b2e7e38c3f022abca3..623a5f185fa1fd0526bc8664e2ba11c9dde79b1d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ 
b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -80,8 +80,9 @@ class IrFunction { // Get the llvm::Value* that represents this functions parameters argument. llvm::Value* parameters_arg() { return parameters_arg_; } - // Get the llvm::Value* that represents this functions "temps" argument. - llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + // Get the llvm::Value* that represents this functions "buffer_table" + // argument. + llvm::Value* buffer_table_arg() { return buffer_table_arg_; } // Get the llvm::Value* that represents this functions "prof_counters" // argument. @@ -108,17 +109,17 @@ class IrFunction { llvm::Argument* result_arg_; llvm::Value* exec_run_options_arg_; llvm::Value* parameters_arg_; - llvm::Value* temp_buffers_arg_; + llvm::Value* buffer_table_arg_; llvm::Value* dynamic_loop_bounds_arg_ = nullptr; llvm::Value* profile_counters_arg_; }; // Returns an array of compute function call argument ir values. 
std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, absl::string_view name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, + llvm::Value* profile_counters_arg); // Emits a call to a runtime fork/join function which dispatches parallel // calls to 'parallel_function' (and joins threads before returning). diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index aedb069dce5419ce02c67009a834d59c91e469b5..f8441c3e345504616485c6b34b4302acd5cc23a3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace cpu { @@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); + /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, + end_index); array_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. 
std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index b4c0c09ec06bac9b5e228428c072948afdd4a547..ede7f433ca6b2cc5629115f800348be9dfb2b93b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + opcode == HloOpcode::kSort || (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index a99cd99c14abb66fc426c43656520e01f34a1700..3822d5300e30704f68b2cf0c7f0b77d595c17a25 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -60,7 +60,7 @@ class ParallelTaskAssignment { // own embedded computation, which is compiled as a parallel compute function, // and which is invoked from a kCall instruction that is lowered in codegen to // a runtime parallel fork/join call. -class ParallelTaskAssigner : public HloPassInterface { +class ParallelTaskAssigner : public HloModulePass { public: // 'max_parallelism': the maximum parallel task count per instruction. 
// 'shape_size': shape size function used by HloCostAnalysis during parallel diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index a84ee78b19981e480858320e445de7f5dae27d61..fad76338a57cd9eb21d9469ca8552efa8ea0129b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -35,9 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index a5f34908d70dd18ec017bdf9833c7df40f80db07..2d9492eacfea34bec3b0f1115e171a5328b7cdc3 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, uint64* prof_counters, int32 num_partitions, + void** buffer_table, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions @@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; 
run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, nullptr, temps, + function(result_ptr, run_options_ptr, nullptr, buffer_table, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; @@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( } // Call first compute function inline. - function(result_ptr, run_options_ptr, params, temps, &partitions[0], + function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0], prof_counters); VLOG(3) << "ParallelForkJoin partition 0 done."; bc.Wait(); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h index 1cf0ec6e3df400e35fa4e755a0b25b4ce7966e8f..a279c7d2d61bdd138f5285a8c8ccc89d22db9692 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -24,7 +24,7 @@ extern "C" { // threads before returning. See comments in runtime_fork_join.cc for details. 
extern void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, tensorflow::uint64* prof_counters, + void** buffer_table, tensorflow::uint64* prof_counters, tensorflow::int32 num_partitions, tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims, void* function_ptr); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0e7deb98e579c090c8fae320a3ba8a3ce8dbe5f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -0,0 +1,236 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" + +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace { +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +template +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements); +} + +// For floating point numbers, we want a total order comparator. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. Also we want to have a stable sort, so if the keys are the +// same, we compare the index values. +template +bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { + bool lhs_is_negative = std::signbit(lhs); + bool rhs_is_negative = std::signbit(rhs); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(lhs); + bool rhs_nan = std::isnan(rhs); + // Exactly one number is nan? 
+ if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; + } + if (lhs != rhs) { + return lhs < rhs; + } + return lhs_index < rhs_index; +} + +template <> +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair* row_to_sort, + int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), lhs.second, + Eigen::half_impl::half_to_float(rhs.first), rhs.second); + }); +} + +template +void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + // High-level idea of the iteration/sorting logic: + // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the + // dimension to sort, c is the product of the more minor dimensions (set to 1 + // if b is the most minor dimension), and a is the product of the more major + // dimensions (set to 1 if b is the most major dimension). There are a * c + // many rows that we need to sort. We iterate through these, calculate a + // 'base_offset' value which points to the first element in that row, and add + // i * c for accessing the 'i'-th element in that row. 
+ + int64 sort_dimension_elements = b; + int64 num_iteration_elements = a * c; + int64 sort_dimension_offset = c; + + std::unique_ptr[]> row_to_sort( + new std::pair[sort_dimension_elements]); + std::unique_ptr reordered_values( + new std::string[sort_dimension_elements]); + for (int64 index = 0; index < num_iteration_elements; ++index) { + // 'index' can be split into two values which index into the 'c' dimension + // and the 'a' dimension, respectively. 'index' % 'c' is the index into the + // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When + // calculating the base offset, we need to multiply the index into the 'a' + // dimension with 'b' * 'c'. + // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'. + int64 base_offset = + index % sort_dimension_offset + + (index - index % sort_dimension_offset) * sort_dimension_elements; + // TODO(b/26783907): We could define a custom iterator class that references + // both arrays. Then we could avoid the intermediate copy. However this + // would become more complicated, and it is not clear if the benefit is high + // enough. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + row_to_sort[i] = + std::make_pair(keys[base_offset + i * sort_dimension_offset], i); + } + KeyValueSort(row_to_sort.get(), sort_dimension_elements); + for (int64 i = 0; i < sort_dimension_elements; ++i) { + keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + } + if (values == nullptr) { + continue; + } + + // Reorder the values according to the order defined by the keys. 
+ for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = + (base_offset + row_to_sort[i].second * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + + reordered_values[i] = std::string(values + memory_index, + values_primitive_type_size_in_bytes); + } + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = (base_offset + i * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + memcpy(values + memory_index, reordered_values[i].c_str(), + values_primitive_type_size_in_bytes); + } + } +} +} // namespace + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( + int8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( + uint8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( + int16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( + uint16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, int64 a, int64 b, 
int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( + int32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( + uint32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( + float* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( + int64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( + uint64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( + double* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h new file mode 100644 index 0000000000000000000000000000000000000000..28e35e82c18cbf078f8a1e7f5b818bf839d3d3df --- /dev/null +++ 
b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' +// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr. +// If 'values' is not nullptr, the elements in 'values' are reordered in such a +// way that if the element at index 'i' in 'keys' was moved to index 'j', the +// element at index 'i' in 'values' is also moved to index 'j' (which means that +// the same elements correspond to each other as before). 
+extern void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS8( + tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU8( + tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS16( + tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU16( + tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS32( + tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU32( + tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF32( + float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS64( + tensorflow::int64* keys, tensorflow::int64 a, 
tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU64( + tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF64( + double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index f227e4ae139b92e56786e38ef8eef72c9e2cd424..55d5925642a97b1a0425c092c82070d4b8e59df3 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -37,21 +37,20 @@ int main(int argc, char** argv) { xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); // Transfer parameters. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); // Build computation. xla::XlaBuilder builder(""); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); @@ -59,17 +58,16 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. 
xla::ExecutionProfile profile; - xla::StatusOr> result = - client->ExecuteAndTransfer( - computation, - /*arguments=*/{param0_data.get(), param1_data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/&profile); - std::unique_ptr actual = result.ConsumeValueOrDie(); + xla::StatusOr result = client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*execution_options=*/nullptr, + /*execution_profile=*/&profile); + xla::Literal actual = result.ConsumeValueOrDie(); - LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", - profile.compute_time_ns()); - LOG(INFO) << actual->ToString(); + LOG(INFO) << absl::StrFormat("computation took %dns", + profile.compute_time_ns()); + LOG(INFO) << actual.ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ae80a6f4977f85cfd9f872734fd0a69432a1f382..1a3d82de954318368d61e3feeb0345dc592dcd8b 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloTestBase { +class ShapePartitionAssignerTest : public HloVerifiedTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloTestBase { +class ShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; }; @@ -102,22 +102,22 @@ TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { { ShapePartitionIterator iterator(shape, {1}); EXPECT_EQ(1, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 5}}), iterator.GetPartition(0))); } { ShapePartitionIterator iterator(shape, {2}); EXPECT_EQ(2, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); - EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 2}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(1))); } { ShapePartitionIterator iterator(shape, {3}); EXPECT_EQ(3, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); - EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); - EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{1, 1}}), iterator.GetPartition(1))); + EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), 
iterator.GetPartition(2))); } } @@ -128,24 +128,24 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { ShapePartitionIterator iterator(shape, {1, 1}); EXPECT_EQ(1, iterator.GetTotalPartitionCount()); EXPECT_TRUE( - ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); + absl::c_equal(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); } { ShapePartitionIterator iterator(shape, {2, 2}); EXPECT_EQ(4, iterator.GetTotalPartitionCount()); EXPECT_TRUE( - ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); + absl::c_equal(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); EXPECT_TRUE( - ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); + absl::c_equal(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); EXPECT_TRUE( - ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); + absl::c_equal(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); EXPECT_TRUE( - ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); + absl::c_equal(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); } } -class RandomShapePartitionIteratorTest : public HloTestBase { +class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index bf98064647f4c29ba689902da4d737e1922391d3..9ec0c8f65705db335379649def746921e6b05bea 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" @@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index a0cd8ee2d2be10bcee9c2e216e24908d949e2d7b..5cdac203af2e7a1f8f3aebda965447ba75e9934e 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/core/platform/logging.h" namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 8b00ae9e47eeed26ffe80707b89593b267e8dbb8..a383b4a4a00f9b8d49a88e8349793a3a90d8da7b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ +#include "absl/container/flat_hash_map.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace cpu { @@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { // This is mutated from within `GetTargetTransformInfoFor` which is // semantically a getter (and thus `const`); and is therefore declared // mutable. Making this mutable is okay because it has cache semantics. 
- mutable tensorflow::gtl::FlatMap + mutable absl::flat_hash_map target_transform_info_cache_; llvm::TargetMachine* target_machine_; }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 2384166fd2002a67a8aa785ad5fb341d037ee01f..4b129c95d46d8b5a119e5d23eef387daf7863cce 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,6 +48,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -121,6 +122,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -178,3 +180,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cpu_key_value_sort_test", + srcs = ["cpu_key_value_sort_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index fcd87b36b32915773546c211d7d2c447a69bef49..18ee25ba9158c28baaf01492c290638b9673f1ec 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ 
b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { HloInstruction* rhs = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "input")); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } @@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { HloInstruction* lhs_transposed = builder.AddInstruction( HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index b68ac67574d0b9f20ecc0370cdaed87d4465b225..1deb412064b02988a8d4a6d726969c948d354d47 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloTestBase { +class CpuFusionTest : public HloVerifiedTestBase { protected: CpuFusionTest() {} @@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); - Shape vshape = input_literal1->shape(); + Shape vshape = input_literal1.shape(); auto input1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal1))); @@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. 
- LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); + LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, result, error_spec_); } TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -122,20 +122,19 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. - LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, - error_spec_); + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, result, error_spec_); } -TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { - // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { + // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. 
auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { << fusion_instruction2->fused_instructions_computation()->ToString(); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, - *result, error_spec_); + result, error_spec_); } TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { @@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // Run fusion. 
CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b..5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - 
TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { // Send 5 Infeed data of shape F32[3]. 
ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 3 infeed data should be added. - LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(45.0f, result_literal, ErrorSpec{1e-7}); } // Tests two Infeed operations with a total order. The order is enforced by @@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). 
ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({3, 4}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({5, 6}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8}), + LiteralUtil::CreateR0(false)}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). 
sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8, 9}), + LiteralUtil::CreateR0(false)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({4, 5, 6}), + LiteralUtil::CreateR0(true)}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 6 infeed data should be added. - LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(66.0f, result_literal, ErrorSpec{1e-7}); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 9457e57d7bb31a56b7a96efdbc52f65988866129..a434c04a980b9b3cd849792b97a0d9e965ba09f2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -65,8 +65,8 @@ class CpuUnaryIntrinsicTest features = ""; } - return absl::StrCat(opcode.c_str(), "_On_", triple.c_str(), - features.empty() ? "" : "_With", features.c_str()); + return absl::StrCat(opcode, "_On_", triple, + (features.empty() ? 
"" : "_With"), features); } }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3934c03a04c978009282b3cd0d39bacf9b12a356 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuKeyValueSortTest : public CpuCodegenTest {}; + +TEST_F(CpuKeyValueSortTest, SortR1) { + const string hlo_text = R"( +HloModule KeyValueSort + +ENTRY main { + a = f32[10] parameter(0) + + ROOT result = f32[10] sort(f32[10] a), dimensions={0} +} +)"; + + string filecheck_pattern = R"( +CHECK: call void @__xla_cpu_runtime_KeyValueSort +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + 
CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index bb105194f1c9001ca4d9fff9174e1ea7e5d8b72a..b35fd9dad877c319c3d0110c96a00aeefa78769e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {}; TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); @@ -122,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) { CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 - CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]} CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 780c07f819ea2f94ed2f27dc0be0983f0389bfbc..e2c7af541eede5265f274c72f55305549f059839 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -54,6 +54,33 @@ CHECK: private constant [48 x i8] /*match_optimized_ir=*/false); } +TEST_F(CpuOutfeedTest, 
OutfeedTokenInTuple) { + const string hlo_text = R"( +HloModule OutfeedTokenInTuple + +ENTRY main { + const = f32[] constant(42) + epoch = token[] after-all() + outfeed.tok = token[] outfeed(const, epoch) + ROOT root = (token[], f32[]) tuple(outfeed.tok, const) +} +)"; + + string filecheck_pattern = R"( +CHECK: Outfeed +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 962ea69c09487735a7d5e3309dfbf2969655da81..1bd4b59dd604687589eee061d34aa9ca94f6d700 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -428,7 +428,7 @@ std::vector TileVariable::Get() const { return result; } -void TileVariable::Set(tensorflow::gtl::ArraySlice value) { +void TileVariable::Set(absl::Span value) { CHECK_EQ(value.size(), storage_.size()); for (int64 i = 0, e = value.size(); i < e; i++) { storage_[i].Set(value[i]); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index c728f6df0aef83e6ddc6c932a347f14da06d9d0d..5690d2be2fe3e21c96b51a5226e0b29148217fd1 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -18,12 +18,12 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -324,7 +324,7 @@ class TileVariable { std::vector initial_value); std::vector Get() const; - void Set(tensorflow::gtl::ArraySlice value); + void Set(absl::Span value); private: std::vector storage_; diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index 47543b2082f55cf7b8cf60f1c5bbb16a0a609912..b9e47f5aade3334bece28643e6e32ecfce3bf67b 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() { } void XfeedQueueManager::EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers) { + absl::Span buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index b4ace232607e14fbfec01d48946f0031d96cd027..990ff94ba2338cb663b655ca3106bda83ab718a3 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -22,10 +22,10 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -63,8 +63,7 @@ class XfeedQueueManager { // called when the buffer will no longer be accessed by the XfeedManager, // either as a result of a call to Reset or because the runtime has dequeued // and used the buffer. - void EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers); + void EnqueueBuffersAtomically(absl::Span buffers); // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. Sets the current buffer to be the returned buffer. It is an diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index 8fe65f488a2f0c4031926fa4c5f02dcf5473568d..cc38b81455b5a35cdcd07ac1dfb80cc7b101a7bc 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) { auto shape = ShapeUtil::MakeShape(U8, {length}); string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - length, bytes.data(), bytes.size()); - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer, - bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } // Performs the acquire/release sequence on the outfeed, as the generated CPU @@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) { void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) { string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - 
length, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - length, buffer, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } TEST_F(InfeedManagerTest, SingleThreadedSequential) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); xfeed->infeed()->EnqueueBuffersAtomically({b}); @@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); ProcessNextBuffer(a->length()); @@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TEST_F(InfeedManagerTest, MultiThreaded) { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); const int32 length = 64; @@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) { TEST_F(InfeedManagerTest, OutfeedWrongShape) { TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->outfeed()->EnqueueBuffersAtomically({b}); ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33})); diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc index 
d124f74d19d83269be96ee34a6b4b2a8d00a978f..661539cccb4ef27a49a73f97a0a8b0d9dfc77061 100644 --- a/tensorflow/compiler/xla/service/defuser.cc +++ b/tensorflow/compiler/xla/service/defuser.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) { fusion_instruction->fused_instructions_computation(); // A map from fused instruction to its defused clone. - tensorflow::gtl::FlatMap + absl::flat_hash_map defused_instructions; // Initialize map to contain the fusion instruction parameters mapping // to the operands of the fusion instruction. diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index c326beb899f9a434d772c0fda032efc9113b6f42..aaa41fc4fe779cdf01a34e86855cac02552ee383 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -25,7 +25,7 @@ namespace xla { // A pass which replaces all fusion instructions with the equivalent un-fused // instructions. -class Defuser : public HloPassInterface { +class Defuser : public HloModulePass { public: Defuser() {} ~Defuser() override {} diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index 37d1895d41447ba0219bb57170e61154fdd8bcdd..e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -26,11 +26,6 @@ namespace xla { namespace { class DefuserTest : public HloVerifiedTestBase { - public: - DefuserTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - protected: // Returns the number of fusion instructions in the module. 
int FusionCount() { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index ba2a674d9af547ad574ae49e1e87f3afcaf6112a..b3549acfc291a54b2345b006310613c3a45a4b47 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -24,7 +24,7 @@ namespace xla { namespace { // Pass which strips control dependencies from all instructions in the module. -class ControlDepRemover : public HloPassInterface { +class ControlDepRemover : public HloModulePass { public: ControlDepRemover() = default; absl::string_view name() const override { return "control-dep-remover"; } diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index 7be70add2f7566376b3179740e411d6341badf7c..46dcc3a438cbdf3ff1b3c99fa15b35ee7a4e280e 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -30,7 +30,7 @@ namespace xla { // // Current despecialization passes are Defuser, ImplicitBroadcastRemover, // and BFloat16MixedPrecisionRemoval. 
-class Despecializer : public HloPassInterface { +class Despecializer : public HloModulePass { public: Despecializer(); absl::string_view name() const override { return "despecializer"; } diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index e228bb56bce8febcca28ae171f6de90973d020ab..edbcb25247421cdb50a845df1ec8b1851970efe3 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -25,7 +25,7 @@ namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors) + absl::Span stream_executors) : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} @@ -36,9 +36,8 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( se::DeviceMemoryBase result = stream_executor->AllocateArray(size); if (size > 0 && result == nullptr) { return ResourceExhausted( - "Failed to allocate request for %s (%lluB) on device ordinal %d", - tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, - device_ordinal); + "Failed to allocate request for %s (%uB) on device ordinal %d", + tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); } return OwningDeviceMemory(result, device_ordinal, this); } @@ -61,12 +60,12 @@ StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( } if (device_ordinal >= stream_executors_.size()) { return InvalidArgument( - "device ordinal value (%d) >= number of devices (%zu)", device_ordinal, + "device ordinal value (%d) >= number of devices (%u)", device_ordinal, stream_executors_.size()); } if (stream_executors_[device_ordinal] == nullptr) { return NotFound("Device %s:%d present but not supported", - platform()->Name().c_str(), device_ordinal); + platform()->Name(), device_ordinal); } return stream_executors_[device_ordinal]; } diff --git 
a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index d87b86caf0d3acaa5bf9a455cff2315cedb2496d..a2308ee7a4137bbafe9804c30e33cc68d4628588 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span stream_executors); StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 2172ae0a29626660e8abd29a789e0baa3831519d..c54f81e6915a286757e59821c2684a7271889816 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -28,14 +28,14 @@ template Status DfsHloVisitorBase::HandleElementwiseUnary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template Status DfsHloVisitorBase::HandleElementwiseBinary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template @@ -50,7 +50,7 @@ void DfsHloVisitorBase::SetVisiting( const HloInstruction& instruction) { VLOG(3) << "marking HLO 
" << &instruction << " as visiting: "; DCHECK(NotVisited(instruction)); - visit_state_.SetState(instruction.unique_id(), VisitState::kVisiting); + visit_state_[instruction.unique_id()] = VisitState::kVisiting; } template @@ -58,7 +58,7 @@ void DfsHloVisitorBase::SetVisited( const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; DCHECK(NotVisited(instruction) || IsVisiting(instruction)); - visit_state_.SetState(instruction.unique_id(), VisitState::kVisited); + visit_state_[instruction.unique_id()] = VisitState::kVisited; } template diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 275e6cc61d84b77312ba3d786c557cbb9f8c3a38..4159aa281fa2b66d310d7c135f123a5a3bb83270 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,15 +19,15 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -107,6 +107,7 @@ class DfsHloVisitorBase { virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -263,21 +264,25 @@ class DfsHloVisitorBase { kVisited = 2, }; 
- VisitState GetVisitState(int id) { return visit_state_.GetState(id); } + VisitState GetVisitState(int id) { + auto iter = visit_state_.find(id); + if (iter == visit_state_.end()) { + return VisitState::kNotVisited; + } + return iter->second; + } VisitState GetVisitState(const HloInstruction& instruction); // Resize internal state if necessary to hold state for ids <= num. // This call is purely a performance hint and can be omitted without // affecting correctness. - void ReserveVisitStates(int num) { visit_state_.Reserve(num); } + void ReserveVisitStates(int num) { visit_state_.reserve(num); } // Useful when we want to visit the same computation more than once with the // same visitor. - void ResetVisitStates() { visit_state_.Reset(); } + void ResetVisitStates() { visit_state_.clear(); } - void SetVisitState(int id, VisitState state) { - visit_state_.SetState(id, state); - } + void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } // Sets the visitation state of the given instruction as kVisiting. 
// @@ -326,44 +331,7 @@ class DfsHloVisitorBase { virtual Status Postprocess(HloInstructionPtr hlo); private: - class DFSVisitStates { - public: - DFSVisitStates() {} - void Reserve(uint64 num) { - states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord); - } - VisitState GetState(uint64 id) { - uint64 word_index = id / kStatesPerWord; - if (word_index >= states_.size()) { - return VisitState::kNotVisited; - } - static_assert(static_cast(VisitState::kVisited) < 3, - "VisitState must fit in two bits"); - uint64 w = states_[word_index]; - uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state - return static_cast((w >> shift) & 0x3); - } - void SetState(uint64 id, VisitState state) { - uint64 word_index = id / kStatesPerWord; - if (word_index >= states_.size()) { - states_.resize(word_index + 1, 0); - } - uint64* w = &states_[word_index]; - uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state - uint64 mask = 0x3ull << shift; - *w = (*w & ~mask) | (static_cast(state) << shift); - DCHECK_EQ(GetState(id), state); - } - void Reset() { states_.clear(); } - - private: - static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; - // Map from id to two-bit states. We store 32 such states per 64-bit - // value - std::vector states_; - }; - - DFSVisitStates visit_state_; + absl::flat_hash_map visit_state_; TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); }; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6ec4893f7ae90eda8bb729c384881b9d11df90e2..4cd10ab06cd3b804406607212d3f3c316d6cff95 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -17,13 +17,13 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleAllToAll(HloInstructionPtr crs) override { - return DefaultAction(crs); + Status HandleAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermute(HloInstructionPtr hlo) override { + return DefaultAction(hlo); } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 09cb10d6ee579111b6e0cdb460b9af2b95d090db..b2ba2617902104bfea06713332fa1c2aedea536d 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( - dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); - dot_r2->set_precision_config(dot->precision_config()); + auto dot_r2 = computation->AddInstruction( + HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2, + dot_dnums, dot->precision_config())); // Reshape Dot to R3 so we can concat along batch dimension. 
auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index fc38e317001695921d20f9bbe5775e61a8eeaa45..40e7a3b4c25ff20674de0cce3fe2907fc43a5cb9 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -23,7 +23,7 @@ namespace xla { // DotDecomposer is a pass which decomposes batch Dot operations into a // sequence of smaller (R2) Dot operations. -class DotDecomposer : public HloPassInterface { +class DotDecomposer : public HloModulePass { public: // Decomposes batch Dot operations when 'decompose_batch_dot' is true. DotDecomposer(bool decompose_batch_dot = true) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 26af67cc1c78d6ffc93b62a66e0f60a8bdec611c..515267edd7caf42e04ebe638b99006db8967ea30 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -28,6 +28,8 @@ limitations under the License. 
#include "llvm/IR/Intrinsics.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -204,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, } // namespace StatusOr ElementalIrEmitter::EmitUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { if (op->opcode() == HloOpcode::kCopy) { return operand_value; } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || @@ -218,7 +220,7 @@ StatusOr ElementalIrEmitter::EmitUnaryOp( } StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -230,14 +232,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( - operand_value->getType(), 0)), + ICmpNE(operand_value, + llvm::ConstantInt::get(operand_value->getType(), 0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsIntegralType(to_type)) { - return b_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(from_type)); + return IntCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_), + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -253,19 +255,17 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( 
primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, b_->CreateSIToFP(operand_value, to_ir_component_type), - nullptr); + op, SIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, b_->CreateUIToFP(operand_value, to_ir_component_type), - nullptr); + op, UIToFP(operand_value, to_ir_component_type), nullptr); } } return Unimplemented("conversion from primitive type %s to %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -276,14 +276,13 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -293,8 +292,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (is_signed) { auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto cmp = b_->CreateICmpSGE(operand_value, GetZero(type)); - return Select(cmp, operand_value, b_->CreateNeg(operand_value)); + auto cmp = ICmpSGE(operand_value, GetZero(type)); + return Select(cmp, operand_value, Neg(operand_value)); } else { return operand_value; } @@ -310,34 +309,33 @@ StatusOr 
ElementalIrEmitter::EmitIntegerUnaryOp( << op->shape().element_type(); auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto cmp = b_->CreateICmpEQ(operand_value, GetZero(type)); - auto ashr = b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); - return Select(cmp, GetZero(type), b_->CreateOr(ashr, 1)); + auto cmp = ICmpEQ(operand_value, GetZero(type)); + auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1); + return Select(cmp, GetZero(type), Or(ashr, 1)); } case HloOpcode::kNegate: - return b_->CreateNeg(operand_value); + return Neg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return b_->CreateZExt( - b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return b_->CreateNot(operand_value); + return Not(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } default: return Unimplemented("unary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -354,8 +352,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( - to_component_type, module_)), + FPCast(operand_value, + 
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } if (from_type == BF16) { @@ -371,26 +369,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateFCmpUNE( - operand_value, - llvm::ConstantFP::get(operand_value->getType(), 0.0)), + FCmpUNE(operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return b_->CreateFPToSI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToSI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return b_->CreateFPToUI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToUI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -401,14 +398,13 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + 
PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -446,8 +442,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(operand_value, zero); - auto olt = b_->CreateFCmpOLT(operand_value, zero); + auto oeq = FCmpOEQ(operand_value, zero); + auto olt = FCmpOLT(operand_value, zero); return Select(oeq, zero, Select(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); @@ -459,24 +455,24 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + auto not_infinite = FCmpONE(abs_value, infinity); return b_->CreateZExt(not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return b_->CreateFNeg(operand_value); + return FNeg(operand_value); case HloOpcode::kReal: return operand_value; case HloOpcode::kImag: return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType component_type = primitive_util::IsComplexType(input_type) @@ -488,12 +484,11 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), 
b_->CreateFMul(b, b)); + auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -501,14 +496,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = b_->CreateFAdd(a, one); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), - b_->CreateFMul(b, b)); + auto a_plus_one = FAdd(a, one); + auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -522,11 +515,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex(op, - b_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - b_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex( + op, FPCast(EmitExtractReal(operand_value), to_ir_component_type), + FPCast(EmitExtractImag(operand_value), to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -536,8 +527,7 
@@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), - b_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -548,8 +538,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); - auto imag_result = b_->CreateFMul(exp_a, sin_b); + auto real_result = FSub(FMul(exp_a, cos_b), one); + auto imag_result = FMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -564,14 +554,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), - b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); + return EmitComposeComplex(op, + FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)), + FMul(sin_a, FSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -587,14 +576,13 @@ StatusOr 
ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), - b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); + return EmitComposeComplex(op, + FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)), + FMul(cos_a, FSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -622,74 +610,63 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = - b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( - b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); - auto real_num = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = + FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = FMul(cos_b, cos_b); + auto sin_b_sq = FMul(sin_b, sin_b); + auto real_num = 
FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + FMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = FMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); + FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = b_->CreateFMul( - cos_b_sin_b, - b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); - auto denom = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), - b_->CreateFDiv(imag_num, denom)); + FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = FMul( + cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, FDiv(real_num, denom), + FDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( 
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); + auto oeq = FCmpOEQ(cplx_abs, zero); return Select( oeq, EmitComposeComplex(op, zero, zero), - EmitComposeComplex( - op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); + EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), + FDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex(op, - b_->CreateFNeg(EmitExtractReal(operand_value)), - b_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), + FNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType operand_type = op->operand(0)->shape().element_type(); if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || operand_type == PRED) { @@ -704,21 +681,20 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( } StatusOr ElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kComplex: return EmitComposeComplex(op, 
lhs_value, rhs_value); case HloOpcode::kAdd: - return b_->CreateFAdd(lhs_value, rhs_value); + return FAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateFSub(lhs_value, rhs_value); + return FSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateFMul(lhs_value, rhs_value); + return FMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return b_->CreateFDiv(lhs_value, rhs_value); + return FDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return b_->CreateFRem(lhs_value, rhs_value); + return FRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -755,66 +731,52 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex(op, - b_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex(op, - b_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FSub(EmitExtractImag(lhs_value), 
EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))), + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); + auto oeq = FCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero); return Select( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), - EmitComposeComplex( - op, - b_->CreateFDiv( - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq), - b_->CreateFDiv( - b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), - 
EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq))); + EmitComposeComplex(op, + FDiv(FAdd(FMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq), + FDiv(FSub(FMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas @@ -824,21 +786,19 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return b_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return b_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // (a+bi)^(c+di) = @@ -850,68 +810,71 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = 
EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = b_->CreateFMul(one_half, c); + auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = b_->CreateFNeg(d); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = b_->CreateFMul(one_half, d); - auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), - b_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), - b_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr 
ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, - llvm::Value* x) const { - if (prim_type != F32) { - // TODO(b/34339814): Implement inverse erf for F64. + llvm::Value* x) { + if (prim_type != F16 && prim_type != F32 && prim_type != F64) { return Unimplemented( "Inverse erf is only implemented for element " - "type F32."); + "types F16, F32 and F64."); + } + + // Upcast half to float. + if (prim_type == F16) { + x = b_->CreateFPExt(x, b_->getFloatTy()); } - auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(b_->getFloatTy(), f); + + auto get_float = [&](const double f) { + return llvm::ConstantFP::get(x->getType(), f); }; - auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, + auto multiply_add = [&](absl::Span coefficients, llvm::Value* w) { - llvm::Value* p = getFloat(coefficients.front()); - coefficients.pop_front(); + llvm::Value* p = get_float(coefficients.front()); + coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), get_float(coefficient)); } return p; }; // Approximation for inverse error function from // Giles, M., "Approximating the erfinv function". 
- // The approximation has the form: - // w = log((1-x)*(1+x)) + // The approximation has the form (float version): + // w = -log((1-x)*(1+x)) // if ( w < 5 ) { // w = w - 2.5 // p = sum_{i=1}^n lq[i]*w^i @@ -921,105 +884,179 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {b_->getFloatTy()}); + module_, llvm::Intrinsic::log, {x->getType()}); - llvm::Value* w = b_->CreateFNeg(b_->CreateCall( - logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), - b_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg(Call( + logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))})); llvm::Value* p_addr = - llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); + llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_); + + if (prim_type == F16 || prim_type == F32) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_); + // Handle true BB. + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(2.5f)); + absl::Span lq{ + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + llvm::Value* p = multiply_add(lq, lw); + Store(p, p_addr); + } - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); - // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, b_); - { - llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); - tensorflow::gtl::ArraySlice lq{ - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - llvm::Value* p = multiply_add(lq, lw); - b_->CreateStore(p, p_addr); - } + // Handle false BB. 
+ SetToFirstInsertPoint(if_data.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f)); + absl::Span gq{ + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + llvm::Value* p = multiply_add(gq, gw); + Store(p, p_addr); + } - // Handle false BB. - SetToFirstInsertPoint(if_data.false_block, b_); - { - llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - - llvm::Value* gw = - b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); - tensorflow::gtl::ArraySlice gq{ - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - llvm::Value* p = multiply_add(gq, gw); - b_->CreateStore(p, p_addr); - } + SetToFirstInsertPoint(if_data.after_block, b_); + } else { + DCHECK(prim_type == F64); + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_); + + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(3.125)); + absl::Span c{ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356}; + llvm::Value* p = 
multiply_add(c, lw); + Store(p, p_addr); + } - SetToFirstInsertPoint(if_data.after_block, b_); - llvm::Value* p = b_->CreateLoad(p_addr); - return b_->CreateFMul(p, x); + SetToFirstInsertPoint(if_data.false_block, b_); + llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_); + SetToFirstInsertPoint(if_data_second.true_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25)); + absl::Span t1{ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635}; + llvm::Value* p = multiply_add(t1, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data_second.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0)); + absl::Span t2{ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221}; + llvm::Value* p = multiply_add(t2, gw); + Store(p, p_addr); + } + + 
SetToFirstInsertPoint(if_data.after_block, b_); + } + llvm::Value* p = Load(p_addr); + x = FMul(p, x); + // Trunc back to half if needed. + if (prim_type == F16) { + x = b_->CreateFPTrunc(x, b_->getHalfTy()); + } + return x; } -StatusOr ElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, + llvm::Value* value) { // Compute erfcinv(value) by calculating erfinv(1.0 - value). auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, b_->CreateFSub(one, value)); + return EmitErfInv(prim_type, FSub(one, value)); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); auto negative_half = llvm::ConstantFP::get(type, -0.5); // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. - TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, b_->CreateFAdd(x, one))); + TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. 
- auto for_small_x = - b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); + auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( + auto x_is_small = FCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1027,40 +1064,40 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = b_->CreateFSub(exp_x, one); + auto for_large_x = FSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. 
- auto x_squared = b_->CreateFAdd(x, x); - auto x_squared_over_two = b_->CreateFMul(x_squared, half); - auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); + auto x_squared = FAdd(x, x); + auto x_squared_over_two = FMul(x_squared, half); + auto for_small_x = FAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( - abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + auto x_is_small = + FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, {lhs->getType()}, b_); } StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return Unimplemented("atan2"); } StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return Unimplemented("tanh"); } StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { + const HloInstruction* hlo, llvm::Value* x) { if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } @@ -1091,44 +1128,39 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } -llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) const { +llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) { return llvm::ConstantInt::get(llvm::cast(type), 1); } -llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) const { +llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) { return 
llvm::ConstantInt::get(llvm::cast(type), 0); } -llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) const { +llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) { auto* integer_type = llvm::cast(type); return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue( integer_type->getBitWidth())); } -llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) const { +llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) { auto* integer_type = llvm::cast(type); return llvm::ConstantInt::get( integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth())); } -llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) const { - return b_->CreateICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); -} - -llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow( - llvm::Value* lhs, llvm::Value* rhs) const { - return b_->CreateAnd(b_->CreateICmpEQ(lhs, GetIntSMin(lhs->getType())), - b_->CreateICmpEQ(rhs, GetMinusOne(rhs->getType()))); +llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) { + return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); } -llvm::Value* ElementalIrEmitter::Select(llvm::Value* cond, llvm::Value* if_true, - llvm::Value* if_false) const { - return b_->CreateSelect(cond, if_true, if_false); +llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs, + llvm::Value* rhs) { + return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())), + ICmpEQ(rhs, GetMinusOne(rhs->getType()))); } llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, - bool is_signed) const { + bool is_signed) { // Integer division overflow behavior: // // X / 0 == -1 @@ -1137,16 +1169,15 @@ llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, if (!is_signed) { llvm::Value* udiv_is_unsafe = IsZero(rhs); llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs); - llvm::Value* safe_div = b_->CreateUDiv(lhs, safe_rhs); + llvm::Value* safe_div = UDiv(lhs, 
safe_rhs); return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div); } llvm::Value* has_zero_divisor = IsZero(rhs); llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); - llvm::Value* sdiv_is_unsafe = - b_->CreateOr(has_int_min_overflow, has_zero_divisor); + llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs); - llvm::Value* safe_div = b_->CreateSDiv(lhs, safe_rhs); + llvm::Value* safe_div = SDiv(lhs, safe_rhs); return Select( has_zero_divisor, GetMinusOne(lhs->getType()), @@ -1155,7 +1186,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, - bool is_signed) const { + bool is_signed) { // Integer remainder overflow behavior: // // X % 0 == X @@ -1164,16 +1195,15 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, if (!is_signed) { llvm::Value* urem_is_unsafe = IsZero(rhs); llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs); - llvm::Value* safe_rem = b_->CreateURem(lhs, safe_rhs); + llvm::Value* safe_rem = URem(lhs, safe_rhs); return Select(urem_is_unsafe, lhs, safe_rem); } llvm::Value* has_zero_divisor = IsZero(rhs); llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); - llvm::Value* srem_is_unsafe = - b_->CreateOr(has_int_min_overflow, has_zero_divisor); + llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs); - llvm::Value* safe_rem = b_->CreateSRem(lhs, safe_rhs); + llvm::Value* safe_rem = SRem(lhs, safe_rhs); return Select( has_zero_divisor, lhs, @@ -1182,15 +1212,15 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - 
bool is_signed) const { + bool is_signed) { switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return b_->CreateAdd(lhs_value, rhs_value); + return Add(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateSub(lhs_value, rhs_value); + return Sub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateMul(lhs_value, rhs_value); + return Mul(lhs_value, rhs_value); case HloOpcode::kDivide: return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: @@ -1222,11 +1252,11 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return b_->CreateAnd(lhs_value, rhs_value); + return And(lhs_value, rhs_value); case HloOpcode::kOr: - return b_->CreateOr(lhs_value, rhs_value); + return Or(lhs_value, rhs_value); case HloOpcode::kXor: - return b_->CreateXor(lhs_value, rhs_value); + return Xor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1235,25 +1265,25 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( // UB. 
case HloOpcode::kShiftRightArithmetic: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateAShr(lhs_value, rhs_value), + AShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateShl(lhs_value, rhs_value), + Shl(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateLShr(lhs_value, rhs_value), + LShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { + bool is_signed) { return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, lhs_value, rhs_value), @@ -1262,7 +1292,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { + bool is_signed) { return Select(b_->CreateICmp(is_signed ? 
llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, lhs_value, rhs_value), @@ -1271,7 +1301,7 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const { + int64 operand_no) { CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() << " is not elementwise."; @@ -1312,7 +1342,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) { TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, @@ -1330,17 +1360,17 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // Perform the division using the float type with the same number of bits // as the raw value to avoid overflow. 
if (raw_value_size_in_bits == 32) { - elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); - elem_value = b_->CreateFDiv( - elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + elem_value = UIToFP(elem_value, b_->getFloatTy()); + elem_value = FDiv(elem_value, + llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); } else { - elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); - elem_value = b_->CreateFDiv( + elem_value = UIToFP(elem_value, b_->getDoubleTy()); + elem_value = FDiv( elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); } if (elem_ir_ty != elem_value->getType()) { - elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + elem_value = FPTrunc(elem_value, elem_ir_ty); } } @@ -1348,9 +1378,7 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( switch (hlo->random_distribution()) { case RNG_UNIFORM: { if (elem_ir_ty->isFloatingPointTy()) { - return b_->CreateFAdd( - b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), - a_or_mean); + return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean); } else { // To generate a uniform random value in [a, b) from a raw random sample // in range [0, 2^N), we let range = b - a and return @@ -1363,22 +1391,21 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // the same cost as if the whole warp were to re-sample. So an // efficient re-sampling implementation on GPU would need to do // nontrivial work to share entropy between threads in the warp. 
- auto range = b_->CreateSub(b_or_sigma, a_or_mean); - return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + auto range = Sub(b_or_sigma, a_or_mean); + return Add(a_or_mean, URem(elem_value, range)); } } case RNG_NORMAL: { TF_ASSIGN_OR_RETURN( llvm::Value * r, - EmitErfcInv(elem_prim_ty, - b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), - elem_value))); - return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return FAdd(FMul(r, b_or_sigma), a_or_mean); } default: return InvalidArgument( "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + RandomDistribution_Name(hlo->random_distribution())); } } @@ -1493,8 +1520,7 @@ std::array CalculateSampleValues( // Precondition: the RNG instruction is not fused. llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { VLOG(3) << "Using philox RNG algorithm"; CHECK(!hlo->IsFused()); // A random number generated by the per module random number generator. @@ -1517,7 +1543,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Load the global state variable for the Philox RNG algorithm. llvm::GlobalVariable* rng_state_ptr = llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); - llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value"); // Build and return the elemental IR generator to generate a random value for // the element corresponding to the current thread. @@ -1543,8 +1569,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // element within the sample. 
llvm::Value* elems_per_sample_value = llvm::ConstantInt::get(index_ty, elems_per_sample); - llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); - llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value); std::array counter_values = CalculateSampleValues( sample_idx, hlo_random_value, global_random_number, rng_state, b_); @@ -1552,18 +1578,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Store the four counter_values into the sample_address alloca so we can // load the elem_offset'th one below. for (int idx = 0; idx < 4; ++idx) { - b_->CreateStore(counter_values[idx], - b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + Store(counter_values[idx], + InBoundsGEP(sample_address, b_->getInt32(idx))); } llvm::Type* int64_ty = b_->getInt64Ty(); CHECK(elems_per_sample == 2 || elems_per_sample == 4); llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; // Retrieve the raw value for the current element from the current sample. 
- llvm::Value* raw_elem_value = b_->CreateLoad( - b_->CreateInBoundsGEP( - b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), - elem_offset), + llvm::Value* raw_elem_value = Load( + InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), "raw_elem_value"); return ConvertValueForDistribution(hlo, operand_to_generator, index, @@ -1574,7 +1599,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( StatusOr ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1584,14 +1609,14 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return Select(b_->CreateTrunc(pred_value, b_->getInt1Ty()), on_true_value, + return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, on_false_value); } StatusOr ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1610,14 +1635,14 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); } else { return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); + PrimitiveType_Name(prim_type)); } } StatusOr ElementalIrEmitter::EmitElementalConcatenate( const HloInstruction* hlo, const 
ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const { + const llvm_ir::IrArray::Index& target_index) { const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; @@ -1639,9 +1664,9 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( } llvm_ir::SetToFirstInsertPoint(exit_block, b_); - llvm::PHINode* output = b_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - hlo->operands().size()); + llvm::PHINode* output = + PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = b_->GetInsertPoint(); b_->SetInsertPoint(init_block); @@ -1656,9 +1681,8 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - b_->CreateCondBr( - b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), - true_block, false_block); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, + false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. @@ -1671,11 +1695,10 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( // Subtract the size of the concat dimension of the current operand // from the source index. 
b_->SetInsertPoint(false_block); - source_index[concat_dim] = - b_->CreateSub(source_index[concat_dim], concat_dim_size); + source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); } - b_->CreateUnreachable(); + Unreachable(); b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1683,7 +1706,7 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); @@ -1700,7 +1723,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); int64 largest_valid_start_index = input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); CHECK_GE(largest_valid_start_index, 0); @@ -1720,7 +1743,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = Add(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1728,7 +1751,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( StatusOr ElementalIrEmitter::EmitElementalGather( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { 
const Shape& operand_shape = hlo->operand(0)->shape(); const Shape& indices_shape = hlo->operand(1)->shape(); const Shape& output_shape = hlo->shape(); @@ -1777,7 +1800,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = - b_->CreateSExtOrTrunc(index_component, index_type); + SExtOrTrunc(index_component, index_type); int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. @@ -1801,8 +1824,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = b_->CreateAdd( - operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_index[operand_dim] = + Add(operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { @@ -1826,7 +1849,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const HloInstruction* input_hlo = hlo->operand(0); const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); @@ -1849,7 +1872,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. 
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); int64 largest_valid_start_index = @@ -1865,14 +1888,14 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + + slice_intersection = + And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = + And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); } // Emit: @@ -1889,26 +1912,26 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. 
llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = Sub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - b_->CreateStore(true_value, ret_value_addr); + Store(true_value, ret_value_addr); // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - b_->CreateStore(false_value, ret_value_addr); + Store(false_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const { + const llvm_ir::IrArray::Index& padded_index) { auto index = padded_index; llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { @@ -1916,26 +1939,22 @@ StatusOr ElementalIrEmitter::EmitElementalPad( return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = - b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = b_->CreateAnd(in_bounds, - b_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = b_->CreateAnd( + index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = + And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); + in_bounds = And( in_bounds, - b_->CreateICmpEQ( + ICmpEQ( index_typed_const(0), - b_->CreateURem(index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = b_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() 
+ 1)); - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), + URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); + index[i] = + SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = + And(in_bounds, + ICmpSLT(index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); } // if (in_bounds) { @@ -1951,26 +1970,26 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - b_->CreateStore(operand_value, ret_value_addr); + Store(operand_value, ret_value_addr); SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(padding_value, ret_value_addr); + Store(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. 
- return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalDot( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const { + const llvm_ir::IrArray::Index& dot_result_index) { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); @@ -1998,8 +2017,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); - b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); @@ -2021,42 +2039,37 @@ StatusOr ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = b_->CreateFSub( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); - llvm::Value* product_imag = b_->CreateFAdd( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); - next_accumulator = b_->CreateInsertValue( + llvm::Value* product_real = + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), 
EmitExtractImag(rhs_value))); + llvm::Value* product_imag = + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = InsertValue( current_accumulator, - b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), - {0}); - next_accumulator = b_->CreateInsertValue( + FAdd(EmitExtractReal(current_accumulator), product_real), {0}); + next_accumulator = InsertValue( next_accumulator, - b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), - {1}); + FAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = b_->CreateFAdd(current_accumulator, - b_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value)); } else { - next_accumulator = - b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); + next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value)); } - b_->CreateStore(next_accumulator, accumulator_alloca); + Store(next_accumulator, accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); - return b_->CreateLoad(accumulator_alloca); + return Load(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -2150,10 +2163,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = b_->CreateSub( - llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - 
target_index[dim]); + source_index[dim] = + Sub(llvm::ConstantInt::get(target_index[dim]->getType(), + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } return operand_to_generator.at(operand)(source_index); }; @@ -2167,6 +2180,61 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), hlo->dimensions(), b_)); }; + case HloOpcode::kIota: + return [this, hlo]( + const IrArray::Index& target_index) -> StatusOr { + auto* iota = Cast(hlo); + PrimitiveType element_type = iota->shape().element_type(); + IrArray::Index elem_index = + ShapeUtil::Rank(iota->shape()) > 1 + ? target_index.SourceIndexOfBroadcast( + iota->shape(), + ShapeUtil::MakeShapeWithDescendingLayout( + element_type, + {iota->shape().dimensions(iota->iota_dimension())}), + {iota->iota_dimension()}, b_) + : target_index; + llvm::Value* elem_index_linear = elem_index.linear(); + if (elem_index_linear == nullptr) { + std::vector iota_bound = { + iota->shape().dimensions(iota->iota_dimension())}; + elem_index_linear = elem_index.Linearize(iota_bound, b_); + } + Shape component_shape = + ShapeUtil::ElementIsComplex(iota->shape()) + ? 
ShapeUtil::ComplexComponentShape(iota->shape()) + : iota->shape(); + PrimitiveType component_element_type = component_shape.element_type(); + llvm::Value* iota_result; + if (ShapeUtil::ElementIsIntegral(component_shape)) { + iota_result = b_->CreateIntCast( + elem_index_linear, + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), + /*isSigned=*/false); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + << component_element_type; + llvm::Type* float_ir_type; + if (component_element_type == BF16) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_); + } + llvm::Value* float_val = + b_->CreateUIToFP(elem_index_linear, float_ir_type); + if (component_element_type == BF16) { + iota_result = EmitF32ToBF16(float_val, b_); + } else { + iota_result = float_val; + } + } + if (ShapeUtil::ElementIsComplex(iota->shape())) { + return EmitComposeComplex(iota, iota_result, nullptr); + } else { + return iota_result; + } + }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2232,28 +2300,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); }; } } -llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return b_->CreateExtractValue(value, {0}); +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) { + return ExtractValue(value, {0}); } -llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return b_->CreateExtractValue(value, {1}); +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { + return ExtractValue(value, {1}); } llvm::Value* 
ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const { + llvm::Value* imag) { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = b_->CreateInsertValue( - llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + auto complex = + InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = b_->CreateInsertValue(complex, imag, {1}); + complex = InsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index c037b989292216746f3b9b2e620785ce9afb92ad..d3e2acaabd4f602171def70ccd3d4fd5adce0d0d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -23,12 +23,13 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { -class ElementalIrEmitter { +class ElementalIrEmitter : public IrBuilderMixin { public: using HloToElementGeneratorMap = std::unordered_map; @@ -40,115 +41,114 @@ class ElementalIrEmitter { virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value) const; + llvm::Value* operand_value); virtual StatusOr EmitBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Returns a function to generate an element of the output of `hlo`, given a // map of functions to generate elements of its operands. 
virtual llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); - llvm::IRBuilder<>* b() const { return b_; } - llvm::Module* module() const { return module_; } + llvm::IRBuilder<>* b() { return b_; } + + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return b_; } + + llvm::Module* module() { return module_; } protected: - virtual StatusOr EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitIntegerUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitFloatUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitComplexUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - llvm::Value* IsZero(llvm::Value* v) const; - llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, - llvm::Value* rhs) const; - llvm::Value* GetZero(llvm::Type* type) const; - llvm::Value* GetOne(llvm::Type* type) const; - llvm::Value* GetIntSMin(llvm::Type* type) const; - llvm::Value* GetMinusOne(llvm::Type* type) const; - llvm::Value* Select(llvm::Value* cond, llvm::Value* if_true, - llvm::Value* if_false) const; + llvm::Value* IsZero(llvm::Value* v); + llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* GetZero(llvm::Type* type); + llvm::Value* GetOne(llvm::Type* type); + llvm::Value* GetIntSMin(llvm::Type* type); + llvm::Value* GetMinusOne(llvm::Type* type); llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, - bool is_signed) const; + bool is_signed); llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, - bool is_signed) const; + bool 
is_signed); virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); - virtual StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitComplexBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); virtual StatusOr EmitErfInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual 
StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, - llvm::Value* x) const; + llvm::Value* x); - virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; - virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value); + virtual llvm::Value* EmitExtractImag(llvm::Value* value); // Composes a complex struct. imag may be nullptr for simple cast operations. llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; + llvm::Value* imag); // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its @@ -157,50 +157,50 @@ class ElementalIrEmitter { // Precondition: `hlo` is an elementwise op. 
llvm_ir::IrArray::Index ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const; + int64 operand_no); // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } + virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } StatusOr EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalClamp( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalConcatenate( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const; + const llvm_ir::IrArray::Index& target_index); StatusOr EmitElementalDynamicSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalGather( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalPad( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const; + const llvm_ir::IrArray::Index& padded_index); StatusOr EmitElementalDot( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const; + const 
llvm_ir::IrArray::Index& dot_result_index); llvm::IRBuilder<>* const b_; @@ -215,13 +215,13 @@ class ElementalIrEmitter { // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 5ab07562194a305b2e020befaaf62fedc1c87d7e..852f34e06df35242b13110ae4411b8c969c26019 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -28,8 +28,7 @@ using absl::nullopt; class ElementalIrEmitterExecutionTest : public HloTestBase { protected: - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -57,9 +56,9 @@ ENTRY main { } )"; - std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); - RunTest(hlo_text, {lhs.get(), rhs.get()}); + Literal lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + Literal rhs = LiteralUtil::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {&lhs, &rhs}); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc 
b/tensorflow/compiler/xla/service/executable.cc index 1c9f396b68fa20a03986d81d642d1726b26cd0dc..47c56e2f7fbd9f53be6a2b189c5c36cf4fdcdccb 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -23,16 +24,14 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -using tensorflow::gtl::ArraySlice; namespace xla { StatusOr> Executable::ExecuteOnStreams( - ArraySlice run_options, - ArraySlice> arguments) { + absl::Span run_options, + absl::Span> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); std::vector return_values; @@ -63,7 +62,7 @@ StatusOr> Executable::ExecuteOnStreams( StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - ArraySlice arguments) { + absl::Span arguments) { se::Stream* stream = run_options->stream(); std::unique_ptr timer; if (profile != nullptr) { @@ -155,9 +154,9 @@ Status Executable::DumpHloSnapshot() { const string& directory_path = module_config().debug_options().xla_dump_executions_to(); const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", module.id(), - module.entry_computation_name().c_str(), ++execution_count_); + string filename = + absl::StrFormat("computation_%d__%s__execution_%d", module.id(), + module.entry_computation_name(), 
++execution_count_); return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 98eaeee30a693211ae564a5ef3c373f0364bef11..3a6780f2a67f230cae626ea00cfbf93b4e60d968 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -18,7 +18,10 @@ limitations under the License. #include #include +#include +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -26,18 +29,33 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" namespace xla { +// ExecutionOutput encapsulates the output buffers of a execution and the +// leftover buffers to be released by the caller. 
+struct ExecutionOutput { + ExecutionOutput(ScopedShapedBuffer result, + std::vector to_be_released) + : result(std::move(result)), to_be_released(std::move(to_be_released)) {} + ScopedShapedBuffer result; + + // Leftover buffers for the caller to release. Elements in this list are + // donated input memory buffers that are not reused by XLA as outputs. + std::vector to_be_released; +}; + // A given platform's compiler will produce an Executable -- this is a uniform // interface that is used for launching compiled programs across platforms. class Executable { @@ -63,25 +81,46 @@ class Executable { // Returns a shaped buffer containing the result of the computation. virtual StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) = 0; + absl::Span arguments) = 0; + + // Starts the given program executing on the given stream/executor. + // + // `arguments` are ShapeTree containing the input parameters. For each element + // in the shape tree, if the element holds the ownership of the memory, it is + // considered donated and XLA will potentially reuse it as output buffers. For + // all donated inputs, XLA is also responsible for freeing them. + // + // If an input is donated to XLA but is not reused as output, it is returned + // as an leftover buffer for the caller to release. 
+ virtual StatusOr ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); + } + + virtual StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); + } // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. virtual StatusOr> ExecuteOnStreams( - tensorflow::gtl::ArraySlice - run_options, - tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> - arguments); + absl::Span run_options, + absl::Span> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any // Execute* API call that takes a hlo_execution_profile argument, but must be @@ -97,7 +136,7 @@ class Executable { // given ExecutionProfile if non-null. StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. 
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 70a78c8a2b6f3cf360ca2ac7255f8dc35235125e..997db7c058af6da8ecff399769b85b803e2e5785 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } handle_to_execution_.erase(handle.handle()); @@ -78,7 +78,7 @@ StatusOr ExecutionTracker::Resolve( tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } return it->second.get(); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index 3cccec9862e0f92df478006939552099868121b9..986970f8862472d1db7564254a9c1277750bb6eb 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -26,7 +26,7 @@ namespace xla { // Flattening associates each call site with a unique computation (for // sequential calling contexts) This simplifies buffer assignment and // points-to analysis (see b/36865746 for details). 
-class FlattenCallGraph : public HloPassInterface { +class FlattenCallGraph : public HloModulePass { public: absl::string_view name() const override { return "flatten-call-graph"; } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8f6608241ed02bbb7e9fde9b6d767c002435e777..5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. 
std::unique_ptr MakeScalarComputation() { @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); + std::unique_ptr flat_call_graph = CallGraph::Build(module); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - 
TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..1208a7dda87d7b2a6ad7113e2604e8b9a0fa045b --- /dev/null +++ b/tensorflow/compiler/xla/service/fusion_queue.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. 
+ virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index d889fd8e88ed4008749c116314e9a0c54e6fa63d..cb86c9857936f21d9d2ac6bc22c725b89cca6482 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::ArraySlice; static StatusOr TransposeIndexVectorDimToLast( HloInstruction* start_indices, int64 index_vector_dim) { @@ -225,7 +224,7 @@ static StatusOr> GatherLoopBody( static StatusOr CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice slice_sizes, int64 gather_loop_trip_count, + absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { std::vector accumulator_state_shape_dims; accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); @@ -244,7 +243,7 @@ static StatusOr CreateGatherLoopAccumulatorInitValue( // are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. 
static StatusOr PermuteBatchAndOffsetDims( - HloInstruction* accumulator, ArraySlice offset_dims, + HloInstruction* accumulator, absl::Span offset_dims, int64 output_rank) { std::vector permutation; permutation.reserve(output_rank); @@ -323,7 +322,7 @@ StatusOr GatherExpander::ExpandGather( return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " "supported. This error occurred for %s.", - gather_instr->ToString().c_str()); + gather_instr->ToString()); } TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 7bd9ea598417a931d2df507d472c6a60be05e0bc..2b39359aae9fc01f1a88a2594108b2772788e826 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -23,7 +23,7 @@ namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. 
-class GatherExpander : public HloPassInterface { +class GatherExpander : public HloModulePass { public: absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 0ce2db907b643f3beabd127388370dbe601179e1..bec02e14f951c6d905b7329be5c02896984279d0 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -42,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -126,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_memory.size()); // Element is array-shaped: transfer array data to device buffer. const auto subliteral = LiteralSlice(literal, index); - std::unique_ptr relayed_out_literal; + Literal relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { @@ -139,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. 
relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->untyped_data(); + source = relayed_out_literal.untyped_data(); TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, /*size=*/GetByteSizeRequirement(device_subshape), source, @@ -163,7 +162,7 @@ Status GenericTransferManager::TransferLiteralFromOutfeed( } Status GenericTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice + absl::Span /*executors*/) { return Unimplemented( "Device reset is not yet supported on this platform (b/30481585)"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 6c1a21587a7ef5199afb93715dc57be5139fbc22..86c8b1c145a25149a25e7b272babc5c858d476af 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -55,15 +55,13 @@ class GenericTransferManager : public TransferManager { const Shape& literal_shape, MutableBorrowingLiteral literal) override; - Status ResetDevices( - tensorflow::gtl::ArraySlice executors) override; + Status ResetDevices(absl::Span executors) override; int64 GetByteSizeRequirement(const Shape& shape) const override; protected: Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e53f525517f7cfd49b0ba66693c319ca5d33b17f..4eb5739fe27d228c4d8939c429665f5d50a6e219 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -57,6 +57,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -67,9 +68,7 @@ cc_library( # 
srcs = [ # "partition_assignment_test.cc", # ], -# tags = [ -# "requires-gpu-sm35", -# ], +# tags = tf_cuda_tests_tags(), # deps = [ # ":partition_assignment", # "//tensorflow/core:stream_executor_no_cuda", @@ -92,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -107,9 +107,12 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -130,6 +133,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -150,7 +154,7 @@ cc_library( deps = [ ":backend_configs", ":buffer_allocations", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":elemental_ir_emitter", ":gpu_constants", ":gpu_executable", @@ -169,12 +173,14 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", 
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -189,6 +195,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -234,6 +241,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -254,6 +262,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -314,7 +323,7 @@ cc_library( ], deps = [ ":buffer_allocations", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", @@ -349,9 +358,12 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -360,6 +372,7 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ + ":backend_configs", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -372,32 +385,37 @@ cc_library( ) cc_library( - name = "cudnn_convolution_algorithm_picker", - srcs = ["cudnn_convolution_algorithm_picker.cc"], - hdrs = ["cudnn_convolution_algorithm_picker.h"], + name = "cudnn_conv_algorithm_picker", + srcs = ["cudnn_conv_algorithm_picker.cc"], + hdrs = 
["cudnn_conv_algorithm_picker.h"], deps = [ ":backend_configs", ":buffer_comparator", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", ], ) cc_library( - name = "cudnn_convolution_runner", - srcs = ["cudnn_convolution_runner.cc"], - hdrs = ["cudnn_convolution_runner.h"], + name = "cudnn_conv_runner", + srcs = ["cudnn_conv_runner.cc"], + hdrs = ["cudnn_conv_runner.h"], deps = [ + ":backend_configs", + ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -406,16 +424,19 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) cc_library( - name = "cudnn_convolution_rewriter", - srcs = ["cudnn_convolution_rewriter.cc"], - hdrs = ["cudnn_convolution_rewriter.h"], + name = "cudnn_conv_rewriter", + srcs = ["cudnn_conv_rewriter.cc"], + hdrs = ["cudnn_conv_rewriter.h"], deps = [ + ":backend_configs", ":ir_emission_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -428,17 +449,17 @@ cc_library( ) tf_cc_test( - name = "cudnn_convolution_rewriter_test", - srcs = ["cudnn_convolution_rewriter_test.cc"], + name = "cudnn_conv_rewriter_test", + srcs = ["cudnn_conv_rewriter_test.cc"], 
deps = [ - ":cudnn_convolution_rewriter", + ":cudnn_conv_rewriter", ":ir_emission_utils", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -449,12 +470,14 @@ cc_library( srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":gpu_fusible", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -469,6 +492,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -478,6 +502,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":gpu_fusible", ":instruction_fusion", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", @@ -485,6 +510,7 @@ cc_library( "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -518,6 +544,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -526,6 +553,7 @@ cc_library( srcs = ["fusion_merger.cc"], hdrs = ["fusion_merger.h"], deps = 
[ + ":gpu_fusible", ":instruction_fusion", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -553,9 +581,9 @@ tf_cc_test( ) cc_library( - name = "pad_insertion", - srcs = ["pad_insertion.cc"], - hdrs = ["pad_insertion.h"], + name = "cudnn_conv_padding_legalization", + srcs = ["cudnn_conv_padding_legalization.cc"], + hdrs = ["cudnn_conv_padding_legalization.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal", @@ -563,6 +591,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", @@ -571,28 +600,25 @@ cc_library( ) cc_library( - name = "pad_for_tensor_cores", - srcs = ["pad_for_tensor_cores.cc"], - hdrs = ["pad_for_tensor_cores.h"], + name = "cudnn_conv_pad_for_tensor_cores", + srcs = ["cudnn_conv_pad_for_tensor_cores.cc"], + hdrs = ["cudnn_conv_pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:shape_inference", ], ) tf_cc_test( - name = "pad_for_tensor_cores_test", - srcs = ["pad_for_tensor_cores_test.cc"], + name = "cudnn_conv_pad_for_tensor_cores_test", + srcs = ["cudnn_conv_pad_for_tensor_cores_test.cc"], deps = [ + ":cudnn_conv_pad_for_tensor_cores", ":ir_emission_utils", - ":pad_for_tensor_cores", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", @@ 
-634,21 +660,22 @@ cc_library( srcs = ["nvptx_compiler.cc"], hdrs = ["nvptx_compiler.h"], deps = [ - ":cudnn_convolution_algorithm_picker", - ":cudnn_convolution_rewriter", + ":cudnn_conv_algorithm_picker", + ":cudnn_conv_pad_for_tensor_cores", + ":cudnn_conv_padding_legalization", + ":cudnn_conv_rewriter", + ":cudnn_fused_conv_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_schedule", ":gpu_hlo_support_checker", ":gpu_layout_assignment", - ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", - ":pad_for_tensor_cores", - ":pad_insertion", ":partition_assignment", ":stream_assignment", ":stream_executor_util", @@ -663,7 +690,6 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:convolution_feature_group_converter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -680,7 +706,6 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", - "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -694,9 +719,11 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", ], alwayslink = True, # Contains compiler registration @@ -754,7 +781,6 @@ 
cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ - ":gpu_options", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", @@ -763,6 +789,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -789,36 +816,39 @@ tf_cc_test( ) cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], + name = "gpu_hlo_schedule", + srcs = ["gpu_hlo_schedule.cc"], + hdrs = ["gpu_hlo_schedule.h"], deps = [ ":stream_assignment", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], ) tf_cc_test( - name = "hlo_schedule_test", + name = "gpu_hlo_schedule_test", srcs = [ - "hlo_schedule_test.cc", + "gpu_hlo_schedule_test.cc", ], deps = [ - ":hlo_schedule", + ":gpu_hlo_schedule", ":stream_assignment", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -851,16 +881,6 @@ cc_library( ], ) -cc_library( - name = "gpu_options", - srcs = ["gpu_options.cc"], - hdrs = ["gpu_options.h"], 
- deps = [ - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:lib_internal", - ], -) - cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], @@ -869,7 +889,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -882,6 +904,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -918,3 +941,42 @@ xla_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "gpu_fusible", + srcs = ["gpu_fusible.cc"], + hdrs = ["gpu_fusible.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:hlo", + ], +) + +tf_cc_test( + name = "gpu_fusible_test", + srcs = ["gpu_fusible_test.cc"], + deps = [ + ":gpu_fusible", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cudnn_fused_conv_rewriter", + srcs = ["cudnn_fused_conv_rewriter.cc"], + hdrs = ["cudnn_fused_conv_rewriter.h"], + deps = [ + ":backend_configs", + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:stream_executor_no_cuda", + ], +) diff --git 
a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 640c6392b8b820c708b853c2a3cea4d4116e85a8..78e14d860e31ace2fcb3f51fb8e0c40a0ea5f3dd 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -24,4 +24,18 @@ message CudnnConvBackendConfig { // true, cudnn may choose not to use tensor cores, e.g. because the GPU or // selected algorithm doesn't support it. bool tensor_ops_enabled = 2; + + // The scaling factor multiplied with the convolution result. + double conv_result_scale = 4; + + // Below are the fields related to cuDNN's fused convolution. Refer to + // CudnnConvParams for their meanings. + + // The requested activation (e.g. relu) after the convolution. It is with type + // stream_executor::dnn::ActivationMode. + int64 activation_mode = 3; + + // The scaling factor multiplied with the side input. If no side input buffer + // is provided, this field must be 0. 
+ double side_input_scale = 5; } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index e208ad61e331ecac12fe128359da7585a2a3a7b4..528209abc75777440163c2e1512658b8ad36315b 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -62,7 +62,7 @@ StatusOr> BufferAllocations::Builder::Build( if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( - "Address of registered buffer %lld must be a multiple of %llx, but " + "Address of registered buffer %d must be a multiple of %x, but " "was %p", i, kEntryParameterAlignBytes, address.opaque()); } @@ -83,7 +83,7 @@ StatusOr> BufferAllocations::Builder::Build( 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " - "multiple of %llx, but was %p", + "multiple of 0x%x, but was %p", kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index f13eab0dd787a2bfa687c991f9d808568360fd24..14186b8faa68ad8492ea4863fcd7bd746e2eae48 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,10 +20,10 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index f22c2a8add035ba16a2888e881a287e974db58f0..13c83c9199fb1bbd8b00dbd601afcb677f92bbee 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -124,7 +124,7 @@ StatusOr F16BufferComparator::Create( StatusOr F16BufferComparator::CompareEqualImpl( se::DeviceMemory test_buffer) { if (ref_buffer_.root_buffer().size() != test_buffer.size()) { - return InternalError("Mismatched buffer size: %lld vs %lld", + return InternalError("Mismatched buffer size: %d vs %d", ref_buffer_.root_buffer().size(), test_buffer.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 8b0426aa27fa3fbc7225dda81cef17e543f1cf28..9ed523998bf07567133fdac0e40b12b8ce4ea3b0 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to retrieve predicate value on stream %p: %s.", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } // Execute the true or the false computation depending on the value of the diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 
854a2f50b2cdfd7c6651424f6aa9e5f2530ad2e8..e1dffad3045808c4f316ccafdda39a174e1560c8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -18,77 +18,49 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -using se::dnn::AlgorithmDesc; - ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} + const HloCustomCallInstruction* cudnn_call, + std::vector operand_slices, + BufferAllocation::Slice 
result_slice, BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + operand_buffers_(std::move(operand_slices)), + result_buffer_(result_slice), + scratch_buffer_(scratch_slice), + tuple_result_buffer_(tuple_result_slice) {} Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + std::vector operand_se_buffers; + for (const auto& buffer : operand_buffers_) { + operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); + } + + se::DeviceMemoryBase result_buffer = + buffer_allocations.GetDeviceAddress(result_buffer_); + se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, - stream)); + TF_RETURN_IF_ERROR(RunCudnnConv(cudnn_call_, + absl::MakeSpan(operand_se_buffers), + result_buffer, scratch, stream)); - // Figure out which of output/input/filter is the result produced by - // this op, and write the result tuple. 
- void* result_ptr = [&] { - switch (convolution_kind_) { - case CudnnConvKind::kForward: - return output_data.opaque(); - case CudnnConvKind::kBackwardInput: - return input_data.opaque(); - case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); - } - }(); - void* ptrs[] = {result_ptr, scratch.opaque()}; + void* ptrs[] = {result_buffer.opaque(), scratch.opaque()}; se::DeviceMemory tuple_addr( buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); stream->ThenMemcpyH2D(ptrs, &tuple_addr); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index f7952787c1db45955c88197e99197ca134b742d1..c71515490c94ef54baad9005509d1813de630159 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -19,11 +19,12 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -32,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. 
@@ -41,26 +42,12 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. - // - // Note that "output" here doesn't refer to the output from running this - // thunk, but rather to the "output" of a hypothetical forward convolution - // that corresponds to this input+filter+output triple. That is, the result - // generated by this thunk is "output" for forward convs, "input" for - // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + // operand_slices should be in the same order as cudnn_call->operands(). 
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + std::vector operand_slices, + BufferAllocation::Slice result_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -71,35 +58,11 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - class ScratchAllocator; - - Status Convolve(const se::dnn::BatchDescriptor& input_descriptor, - se::DeviceMemory input_data, - const se::dnn::FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const se::dnn::BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const se::dnn::ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, - se::Stream* stream, ScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result); - - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + std::vector operand_buffers_; + BufferAllocation::Slice result_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index 6e2e330edd4beabe0b395f05b80d57612d63f110..c3f58508ddd4451312325b0d440473515812dac9 100644 --- 
a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -52,7 +52,7 @@ namespace gpu { // The GPU backend does not implement a lowering for the batchnorm HLOs -- it // expects them to be lowered to cudnn calls via this pass or to HLO soup via // BatchNormRewriter. -class CudnnBatchNormRewriter : public HloPassInterface { +class CudnnBatchNormRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 18a76e8c26150db47c064d76492ef6c1521e2745..bc3c6f72f6799f84169748465d62c3f2a306d5fc 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc similarity index 60% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 3d421ebb693a64229746d5b90107039507a3d457..6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/mutex.h" @@ -59,8 +61,8 @@ StatusOr> ScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -74,54 +76,24 @@ StatusOr> ScratchAllocator::AllocateBytes( return se::DeviceMemory(buffer_addr); } -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output shapes. This works around b/68264959, an integer -// overflow in cuDNNv5 and cuDNNv6. -bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, - const Shape& output_shape, - const ConvolutionDimensionNumbers& dnums, - se::StreamExecutor* stream_exec) { - // Skip this check for cudnn7 and newer. 
- auto version = stream_exec->AsDnn()->GetVersion(); - if (version.ok() && version.ValueOrDie().major_version() >= 7) { - return true; - } - - int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); - int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); - int64 in_cols = - dnums.input_spatial_dimensions_size() == 1 - ? 1 - : input_shape.dimensions(dnums.input_spatial_dimensions(1)); - int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); - - int64 total_size = CeilOfRatio(batch, int64{16}) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - - const int64 threshold = 1L << 31; - return total_size < threshold; -} - std::vector GetAlgorithms(CudnnConvKind kind, - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) { std::vector algorithms; + bool succ = false; switch (kind) { case CudnnConvKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = + stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms); break; case CudnnConvKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); break; case CudnnConvKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); + case CudnnConvKind::kForwardActivation: + succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); break; } + DCHECK(succ); return algorithms; } @@ -173,16 +145,12 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. 
-StatusOr> -CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { - CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); - CHECK_EQ(input_shape.element_type(), output_shape.element_type()); +StatusOr +CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. - const bool cross_check_enabled = input_shape.element_type() == xla::F16; + const bool cross_check_enabled = + instr->shape().tuple_shapes(0).element_type() == xla::F16; // Don't run this function concurrently on the same GPU. // @@ -191,6 +159,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // concurrently and then run them sequentially. tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Make sure any previous activity on this executor is done. We don't want to + // interfere with programs that are still running on the GPU. + if (!stream_exec_->SynchronizeAllActivity()) { + return InternalError("Failed to synchronize GPU for autotuning."); + } + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); @@ -203,38 +177,24 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( if (allocator_ != nullptr) { allocator = allocator_; } else { - se_allocator.emplace( - stream_exec_->platform(), - tensorflow::gtl::ArraySlice({stream_exec_})); + se_allocator.emplace(stream_exec_->platform(), + absl::Span({stream_exec_})); allocator = &*se_allocator; } - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. 
- ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(output_shape))); - - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. - const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); size_t left_over_bytes = buffer.size() % 4; CHECK_EQ(0, left_over_bytes % 2); constexpr float kBroadcastedConstant = 0.1f; - Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), - Eigen::half(kBroadcastedConstant)}; + static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), + Eigen::half(kBroadcastedConstant)}; uint32 bits; static_assert(sizeof(bits) == sizeof(halfs), ""); memcpy(&bits, halfs, sizeof(bits)); @@ -245,54 +205,56 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( DeviceMemoryBase left_over( static_cast(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - }; - initialize_f16(input_buf); - initialize_f16(filter_buf); - initialize_f16(output_buf); - } else { - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. 
It's conceivable that using uninitialized memory as the - // inputs might affect performance if e.g. the inputs contain denormals, and - // this is easy enough. - stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()); - } - TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); - - DeviceMemoryBase* result_buf = [&] { - switch (kind) { - case CudnnConvKind::kBackwardFilter: - return &filter_buf; - case CudnnConvKind::kBackwardInput: - return &input_buf; - case CudnnConvKind::kForward: - return &output_buf; + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); } - }(); + }; + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. 
+ ScratchAllocator input_output_allocator(device_ordinal, allocator); + std::vector operand_buffers; + for (const auto* operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(auto buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + initialize_buffer(buffer); + operand_buffers.push_back(buffer); + } + TF_ASSIGN_OR_RETURN( + auto result_buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); + initialize_buffer(result_buffer); - const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config()); optional comparator; // Use the first algorithm that's supported as reference. There isn't a // particular reason to use it, as any algorithm sufficies. It doesn't make // this algorithm considered correct, though. 
optional first_algorithm; - for (const AlgorithmDesc& alg : - GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); + backend_config.set_algorithm(alg.algo_id()); + backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); + TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); bool launch_ok = - RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, - AlgorithmConfig(alg), &stream, &profile_result) + RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + &scratch_allocator, &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { @@ -303,7 +265,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( .xla_gpu_crash_on_verification_failures(); if (comparator.has_value()) { StatusOr result = comparator->CompareEqual( - se::DeviceMemory(*result_buf)); + se::DeviceMemory(result_buffer)); if (!result.ok()) { LOG(ERROR) << "Unable to compare " << AlgorithmToString(*first_algorithm) << " against " @@ -321,7 +283,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( } } else if (cross_check_enabled) { auto comp = F16BufferComparator::Create( - se::DeviceMemory(*result_buf), compiler_, allocator, + se::DeviceMemory(result_buffer), compiler_, allocator, &stream); if (comp.ok()) { comparator.emplace(comp.ConsumeValueOrDie()); @@ -353,83 +315,53 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << AlgorithmToString(best_result.algorithm()) << ", takes " << best_result.elapsed_time_in_ms() << "ms, and uses " << best_result_bytes_used << "B of scratch memory."; - return 
std::make_tuple(best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used); + return AutotuneResult{best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used, + absl::Milliseconds(best_result.elapsed_time_in_ms())}; } return InternalError( "All algorithms tried for convolution %s failed. Falling back to " "default algorithm.", - instr->ToString().c_str()); + instr->ToString()); } -StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( +StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - const auto& call_target = instr->custom_call_target(); - const auto& lhs_shape = instr->operand(0)->shape(); - const auto& rhs_shape = instr->operand(1)->shape(); - const auto& conv_result_shape = instr->shape().tuple_shapes(0); - StatusOr> alg_scratch_and_tc; - if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = - PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, instr->window(), - instr->convolution_dimension_numbers(), instr); - } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr); - } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instr->ToString(); - } - - if (!alg_scratch_and_tc.ok()) { - LOG(ERROR) << alg_scratch_and_tc.status(); + 
StatusOr best_algo_or = + PickBestAlgorithm(Cast(instr)); + if (!best_algo_or.ok()) { + LOG(ERROR) << best_algo_or.status(); return false; } - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = - alg_scratch_and_tc.ConsumeValueOrDie(); - - VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " - << NumBytesToString(scratch_bytes) + auto best_algo = std::move(best_algo_or).ValueOrDie(); + VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm + << " and " << NumBytesToString(best_algo.scratch_bytes) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; // Replace instr with a new CustomCall which has the correct algorithm, and // whose output shape has the appropriate amount of scratch memory. HloComputation* computation = instr->parent(); - Shape new_call_shape = - ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {scratch_bytes})}); - - CudnnConvBackendConfig backend_config; - backend_config.set_algorithm(algorithm); - backend_config.set_tensor_ops_enabled(tensor_ops_enabled); - - HloInstruction* new_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - instr->custom_call_target())); - new_call->set_window(instr->window()); - new_call->set_convolution_dimension_numbers( - instr->convolution_dimension_numbers()); + Shape new_call_shape = ShapeUtil::MakeTupleShape( + {instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + instr->backend_config()); + backend_config.set_algorithm(best_algo.algorithm); + backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); + + HloInstruction* new_call = 
computation->AddInstruction( + instr->CloneWithNewOperands(new_call_shape, instr->operands())); + + VLOG(1) << "Replacing convolution " << instr->ToString() << " with " + << new_call->ToString(); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely @@ -445,7 +377,7 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( return true; } -StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( +StatusOr CudnnConvAlgorithmPicker::RunOnComputation( HloComputation* computation) { std::vector convs; for (auto* instr : computation->instructions()) { @@ -462,7 +394,7 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( return changed; } -StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { +StatusOr CudnnConvAlgorithmPicker::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h similarity index 68% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index f76d273e8c641dfbdbba38eb161ab8a00a19e1f8..642af787afc71586d722ecc7e529ed8b3fa64d33 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -29,29 +31,32 @@ namespace gpu { // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for // each and adding explicit scratch space to the CustomCalls. -class CudnnConvolutionAlgorithmPicker : public HloPassInterface { +class CudnnConvAlgorithmPicker : public HloModulePass { public: // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. 
- CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator, - Compiler* compiler) + CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator, Compiler* compiler) : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} absl::string_view name() const override { - return "cudnn-convolution-algorithm-picker"; + return "cudnn-conv-algorithm-picker"; } StatusOr Run(HloModule* module) override; private: + struct AutotuneResult { + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + absl::Duration runtime; + }; + StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null @@ -61,4 +66,4 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc similarity index 50% rename from tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc index 79f7d31816baf0b95b967771b956a9c06ac81e91..5aa4f839f4be5f1060480fea98775f8ffada0bdd 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -13,44 +13,27 @@ See the License 
for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace gpu { -using tensorflow::gtl::ArraySlice; - -// We want the input/output feature counts of an f16 conv to be factors of 8, -// because without this cudnn can't use tensor cores on the conv. -static constexpr int64 kDesiredNumFeaturesFactor = 8; - // We won't pad a conv if doing so increases the total number of bytes in the // lhs, rhs, or result by more than this amount. // // TODO(jlebar): This number was tuned experimentally. It represents a // compromise on our current benchmarks; it speeds some up significantly, and // doesn't slow any down. But we can observe by changing this value that -// there's additional room for speedups. Achieving those speedups without also -// slowing other things down will likely require a more sophisticated heuristic, -// possibly some form of auto-tuning. -static constexpr double kMaxBytesTouchedIncrease = 1.2; - -// Pads the given dimensions in the given shape up to a multiple of -// kDesiredNumFeaturesFactor. -static Shape PadShape(Shape s, ArraySlice dims) { - for (int64 dim : dims) { - int64 dim_to_pad_size = s.dimensions(dim); - int64 new_dim_to_pad_size = - RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); - s.set_dimensions(dim, new_dim_to_pad_size); - } - return s; -} +// there's additional room for speedups. 
Achieving those speedups without +// also slowing other things down will likely require a more sophisticated +// heuristic, possibly some form of auto-tuning. +static constexpr double kMaxBytesTouchedIncrease = 1.35; // Creates and returns an HLO that zero-pads one or more dimensions in the given // instruction so that its shape is equal to the given shape. @@ -64,8 +47,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); @@ -87,90 +70,19 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); } -// Pads the input/output feature dimensions of the given cudnn convolution -// custom-call to be multiples of kDesiredNumFeaturesFactor. -static StatusOr PadFeaturesDims(HloInstruction* conv) { +// Modifies the given convolution to have the given LHS/RHS/result shapes. +static Status PadConv(HloCustomCallInstruction* conv, + const Shape& new_lhs_shape, const Shape& new_rhs_shape, + const Shape& new_result_shape) { CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) << "conv must use 0 scratch bytes, i.e. this pass must be run " - "before CudnnConvolutionAlgorithmPicker."; + "before CudnnConvAlgorithmPicker."; - const auto& target = conv->custom_call_target(); - const auto& dnums = conv->convolution_dimension_numbers(); auto* lhs = conv->mutable_operand(0); auto* rhs = conv->mutable_operand(1); - const Shape& result_shape = conv->shape().tuple_shapes(0); - - Shape new_lhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardFilterCallTarget) { - // LHS is "input". 
- return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); - // LHS is "output". - return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); - }(); - - Shape new_rhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardInputCallTarget) { - // RHS is "filter". - return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // RHS is "output". - return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); - }(); - - if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && - ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { - VLOG(3) << "No need to pad features of " << conv->ToString(); - return false; - } - - Shape new_result_shape = [&] { - if (target == kCudnnConvForwardCallTarget) { - // Result is "output". - return PadShape(result_shape, {dnums.output_feature_dimension()}); - } - if (target == kCudnnConvBackwardInputCallTarget) { - // Result is "input". - return PadShape(result_shape, {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // Result is "filter". - return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); - }(); - - // Check that padding wouldn't increase the total bytes read/written by this - // operation too much. 
- auto check_size_increase = [&](const Shape& old_shape, - const Shape& new_shape) { - int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); - int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); - if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { - return true; - } - VLOG(3) << "Not padding convolution; doing so would change input / result " - "shape from " - << ShapeUtil::HumanString(old_shape) << " to " - << ShapeUtil::HumanString(new_shape) << ", a size increase of " - << new_bytes / static_cast(old_bytes) << "x > " - << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); - return false; - }; - if (!check_size_increase(lhs->shape(), new_lhs_shape) || - !check_size_increase(rhs->shape(), new_rhs_shape) || - !check_size_increase(result_shape, new_result_shape)) { - return false; - } - - // OK, let's do the transformation! - auto* new_lhs = PadInstruction(lhs, new_lhs_shape); auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + const Shape& result_shape = conv->shape().tuple_shapes(0); CHECK(new_lhs != lhs || new_rhs != rhs) << "We should have had to pad either LHS or RHS."; @@ -203,26 +115,124 @@ static StatusOr PadFeaturesDims(HloInstruction* conv) { VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " << new_conv->ToString(); - TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); + return conv->parent()->ReplaceInstruction(conv, new_conv); +} + +static StatusOr PadForTensorCores(HloCustomCallInstruction* conv) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + // Nothing to do on non-f16 convolutions. + if (result_shape.element_type() != PrimitiveType::F16) { + return false; + } + + // TODO(timshen): Don't skip forward-activation convs if we find a benchmark + // where there's a speedup. 
+ if (kind == CudnnConvKind::kForwardActivation) { + return false; + } + + Shape new_lhs_shape = lhs->shape(); + Shape new_rhs_shape = rhs->shape(); + Shape new_result_shape = conv->shape().tuple_shapes(0); + + // new_{input,filter_output}_shape points to the appropriate one of + // new_{lhs,rhs,result}_shape. + Shape* new_input_shape; + Shape* new_filter_shape; + Shape* new_output_shape; + std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return std::make_tuple(&new_lhs_shape, &new_rhs_shape, + &new_result_shape); + case CudnnConvKind::kBackwardInput: + return std::make_tuple(&new_result_shape, &new_rhs_shape, + &new_lhs_shape); + case CudnnConvKind::kBackwardFilter: + return std::make_tuple(&new_lhs_shape, &new_result_shape, + &new_rhs_shape); + } + }(); + + // If there are 3 input features and 32 or 64 output features, pad the input + // features to 4. Otherwise, try padding to multiples of 8 and check that + // this doesn't make any of the conv buffers too much larger. 
+ auto input_features = + new_input_shape->dimensions(dnums.input_feature_dimension()); + auto output_features = + new_output_shape->dimensions(dnums.output_feature_dimension()); + if (input_features == 3 && (output_features == 32 || output_features == 64)) { + new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4); + new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4); + } else { + auto pad_dim = [](Shape* s, int64 dim) { + s->set_dimensions(dim, RoundUpToNearest(s->dimensions(dim), 8)); + }; + pad_dim(new_input_shape, dnums.input_feature_dimension()); + pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension()); + pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension()); + pad_dim(new_output_shape, dnums.output_feature_dimension()); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) + << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + } + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + // OK, let's do the transformation! 
+ TF_RETURN_IF_ERROR( + PadConv(conv, new_lhs_shape, new_rhs_shape, new_result_shape)); return true; } -static std::vector GetRelevantConvs(HloComputation* comp) { - std::vector convs; +static std::vector GetRelevantConvs( + HloComputation* comp) { + std::vector convs; for (HloInstruction* instr : comp->instructions()) { - if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16) { - convs.push_back(instr); + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(Cast(instr)); } } return convs; } -StatusOr PadForTensorCores::Run(HloModule* module) { +StatusOr CudnnConvPadForTensorCores::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* conv : GetRelevantConvs(comp)) { - TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); + for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(conv)); changed |= result; } } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h similarity index 51% rename from tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h index 11dc56a64fda74cab12024e5f2c6fa2f63c9167d..d4e51e86c1bf2c1f9aef2eed642604092033a538 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h @@ -13,26 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { -// Ensures that f16 cudnn convolutions have input/output channel dimensions that -// are multiples of 8, inserting pads/slices as necessary. +// Adds padding to cudnn convolutions to make them run faster on GPUs with +// tensor cores. // -// This is useful primarily for Volta and newer GPUs, where tensor cores can -// only be used if the channel dims are multiples of 8. It's probably the -// opposite of useful on other GPUs, so you should check what GPU you're -// targeting before running this pass. +// - f16 convolutions are padded to have input/output channel dimensions that +// are multiples of 8, so that we can use tensor cores. +// +// - f16 convolutions with 3 input channels and 32 or 64 output channels are +// padded to 4 input channels. There's a special-cased cudnn algorithm just +// for this. +// +// Don't run this pass on GPUs without tensor cores -- it will make them slower! // // TODO(jlebar): Also pad dots. 
-class PadForTensorCores : public HloPassInterface { +class CudnnConvPadForTensorCores : public HloModulePass { public: - absl::string_view name() const override { return "pad for tensor cores"; } + absl::string_view name() const override { return "cudnn-conv-pad-for-speed"; } StatusOr Run(HloModule* module) override; }; @@ -40,4 +44,4 @@ class PadForTensorCores : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc similarity index 63% rename from tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc index 104af48c82ab1be9792eff11406af8d2a439e954..fa3afa6a5d318c399dc38e8934199b5a1393669e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -29,15 +29,10 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class PadForTensorCoresTest : public HloVerifiedTestBase { - public: - PadForTensorCoresTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; +class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {}; -TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -46,11 +41,12 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); - SCOPED_TRACE(module().ToString()); + SCOPED_TRACE(module->ToString()); EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, op::Pad(op::Parameter(0), _), op::Pad(op::Parameter(1), _))); @@ -60,8 +56,8 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); } -TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, 
PadF16BackwardInputConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -70,9 +66,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardInput" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, op::Pad(op::Parameter(0), _), op::Pad(op::Parameter(1), _))); @@ -82,8 +79,8 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); } -TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -92,17 +89,18 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvForwardCallTarget, op::Parameter(0), op::Pad(op::Parameter(1), _)))), _)); } -TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { - 
ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -112,9 +110,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardInput" ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvBackwardInputCallTarget, op::Parameter(0), @@ -122,8 +121,8 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { _))); } -TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -133,9 +132,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter" ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvBackwardFilterCallTarget, @@ -143,8 +143,8 @@ TEST_F(PadForTensorCoresTest, 
PadF16BackwardFilterConvInputChannels) { _))); } -TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { - ParseAndVerifyModule(R"( +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( HloModule TestModule ENTRY TestComputation { @@ -154,9 +154,10 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter" ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::GetTupleElement(op::Tuple( op::Slice(op::GetTupleElement(op::CustomCall( kCudnnConvBackwardFilterCallTarget, @@ -164,6 +165,31 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { _))); } +TEST_F(CudnnConvPadForTensorCoresTest, PadInputFeatures3To4) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,3] parameter(0) + filter = f16[2,2,3,32] parameter(1) + ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + + SCOPED_TRACE(module->ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 4}))); + 
EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 4, 32}))); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc similarity index 86% rename from tensorflow/compiler/xla/service/gpu/pad_insertion.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index 98cc21ccac57268257f1f9a3999a3d876ef074fc..d7829045cc127deaa4c2c9b705dca5285d704af2 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -30,7 +31,8 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); + CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || + conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -68,9 +70,8 @@ HloInstruction* 
MaybePaddedAndSlicedInput( conv_window.dimensions(i).base_dilation() - 1); } PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,14 +126,14 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloComputation* computation = kernel->parent(); PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace -bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { +bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( + HloInstruction* conv) { if (IsForwardConvolutionCanonical(*conv)) { return false; } @@ -163,12 +164,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract // out the shape of conv_result. 
- Shape old_conv_shape = conv->shape().tuple_shapes(0); - VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, - new_conv_window, - conv->convolution_dimension_numbers()); + std::vector operands(conv->operands().begin(), + conv->operands().end()); + operands[0] = new_input; + operands[1] = new_kernel; + auto new_conv = conv->parent()->AddInstruction( + conv->CloneWithNewOperands(conv->shape(), operands)); + new_conv->set_window(new_conv_window); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -185,7 +188,7 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { } } // namespace -bool PadInsertion::CanonicalizeBackwardFilterConvolution( +bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { CHECK_EQ(backward_conv->custom_call_target(), kCudnnConvBackwardFilterCallTarget); @@ -236,18 +239,18 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. 
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); - HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( - backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums); + HloInstruction* new_backward_conv = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + backward_conv->shape(), {padded_input, output})); + new_backward_conv->set_window(new_backward_conv_window); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -258,7 +261,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( return true; } -bool PadInsertion::CanonicalizeBackwardInputConvolution( +bool CudnnConvPaddingLegalization::CanonicalizeBackwardInputConvolution( HloInstruction* backward_conv) { if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; @@ -310,9 +313,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( - new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums); + HloInstruction* new_backward_conv_call = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + ShapeUtil::MakeTupleShape( + {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), + {output, filter})); + new_backward_conv_call->set_window(new_backward_conv_window); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. 
@@ -372,31 +378,33 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { +StatusOr CudnnConvPaddingLegalization::RunOnComputation( + HloComputation* computation) { bool changed = false; - std::vector convs; + std::vector convs; for (auto* instr : computation->instructions()) { if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); + convs.push_back(Cast(instr)); } } - for (HloInstruction* instruction : convs) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + for (HloCustomCallInstruction* instruction : convs) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + changed |= [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return CanonicalizeForwardConvolution(instruction); + case CudnnConvKind::kBackwardInput: + return CanonicalizeBackwardInputConvolution(instruction); + case CudnnConvKind::kBackwardFilter: + return CanonicalizeBackwardFilterConvolution(instruction); + } + }(); } return changed; } -StatusOr PadInsertion::Run(HloModule* module) { +StatusOr CudnnConvPaddingLegalization::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h similarity index 78% rename 
from tensorflow/compiler/xla/service/gpu/pad_insertion.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h index a622e894ed9c0d1534262e6b72a5f4ea7b7821ad..7d1b075517fb285222506e0420984906579e681f 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -24,9 +24,11 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to cuDNN convolution. 
-class PadInsertion : public HloPassInterface { +class CudnnConvPaddingLegalization : public HloModulePass { public: - absl::string_view name() const override { return "pad insertion"; } + absl::string_view name() const override { + return "cudnn-conv-padding-legalization"; + } StatusOr Run(HloModule* module) override; @@ -41,4 +43,4 @@ class PadInsertion : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc similarity index 66% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index 905b5ee8767d0fa0514c7f1abf83bc089cd08045..5cea66de38c77b7690d9c9485fa0534af30a0ad6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" +#include #include #include #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -34,6 +36,32 @@ namespace gpu { namespace { +HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. 
+ Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); + return custom_call; +} + bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); @@ -59,6 +87,9 @@ std::tuple MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + if (conv->feature_group_count() > 1) { + return no_match_result; + } // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -157,9 +188,9 @@ std::tuple MatchBackwardFilter( // the amount of high padding the same as the amount of low padding as long // as it is between min_padding_high and max_padding_high. If it is not in // that range, we pick the one that's closest to dim->padding_low() and let - // PadInsertion canonicalize the resultant backward convolution later. - // Picking the closest one minimizes the cost of the kPad instruction to be - // inserted by PadInsertion. + // CudnnConvPaddingLegalization canonicalize the resultant backward + // convolution later. Picking the closest one minimizes the cost of the kPad + // instruction to be inserted by CudnnConvPaddingLegalization. 
if (dim->padding_low() >= min_padding_high && dim->padding_low() <= max_padding_high) { dim->set_padding_high(dim->padding_low()); @@ -176,7 +207,8 @@ std::tuple MatchBackwardFilter( "negative padding (" << dim->padding_high() << ") on right/bottom of the weight gradients, which is not " - "supported by PadInsertion (b/32744257). Falling back to " + "supported by CudnnConvPaddingLegalization (b/32744257). " + "Falling back to " "unfused convolution for instruction: " << conv->ToString(); return no_match_result; @@ -213,42 +245,55 @@ std::tuple MatchBackwardFilter( // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple MatchBackwardInput( - HloInstruction* conv) { +std::tuple +MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); + + // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. + if (conv->feature_group_count() > 1) { + return no_match_result; + } // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); - - // Match the reverse of the filter. 
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); - if (reverse_filter->opcode() == HloOpcode::kReverse) { - if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || - !std::is_permutation(kernel_spatial_dims.begin(), - kernel_spatial_dims.end(), - reverse_filter->dimensions().begin())) { - VLOG(1) - << "Backward input convolution should reverse all kernel dimensions."; - return no_match_result; - } - } else { - // Possibly 1x1 filter. - for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { - if (conv->window().dimensions(i).size() != 1) { - VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " - << reverse_filter->ToString(); - return no_match_result; - } - } - if (!window_util::HasBaseDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. No need " - "to fold it to a backward input convolution."; - return no_match_result; - } + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. 
+ bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dilated conv with a 1x1 or " + "constant filter."; + return no_match_result; } // Match padding and dilation of the forward convolution. @@ -298,7 +343,8 @@ std::tuple MatchBackwardInput( LOG(ERROR) << "The low padding of the backward convolution would be negative (" << backward_padding_low - << "), which isn't supported by PadInsertion for now (b/32744257)."; + << "), which isn't supported by CudnnConvPaddingLegalization " + "for now (b/32744257)."; return no_match_result; } dim->set_padding_low(backward_padding_low); @@ -327,8 +373,8 @@ std::tuple MatchBackwardInput( dim->set_padding_high(backward_padding_low); } else { // Otherwise, we choose the amount that's closest to backward_padding_low, - // and PadInsertion will later insert kSlice instructions to enforce even - // padding. + // and CudnnConvPaddingLegalization will later insert kSlice + // instructions to enforce even padding. // // For example, consider the backward convolution pattern // @@ -354,9 +400,9 @@ std::tuple MatchBackwardInput( dim->set_padding_high(max_padding_high); } } - // PadInsertion doesn't handle backward input convolution with negative - // padding for now. So fall back to unfused convolution in case of negative - // padding. For example, + // CudnnConvPaddingLegalization doesn't handle backward input + // convolution with negative padding for now. So fall back to unfused + // convolution in case of negative padding. 
For example, // ABCD = Conv(abc, reverse(xy), padding_high=2) // could be fused to // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) @@ -366,30 +412,77 @@ std::tuple MatchBackwardInput( "negative padding (" << dim->padding_high() << ") on right/bottom of the activations, which is not " - "supported by PadInsertion (b/32744257). Falling back to " - "unfused convolution for instruction: " + "supported by CudnnConvPaddingLegalization (b/32744257). " + "Falling back to unfused convolution for instruction: " << conv->ToString(); return no_match_result; } } - // Fuse the matched HLOs into a backward convolution instruction. - // - // If the reverse is omitted (for 1x1 filters) in the original pattern, we add - // it back in the fusion instruction so that later passes (such as - // PadInsertion) can handle such fusion instructions easily. - if (reverse_filter->opcode() != HloOpcode::kReverse) { - reverse_filter = reverse_filter->parent()->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); - TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); - } + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. This is the way cuDNN expects it to be. dnums.set_kernel_input_feature_dimension( conv->convolution_dimension_numbers().kernel_output_feature_dimension()); dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { + // Create a double-reverse, which is a nop. 
+ HloComputation* c = conv->parent(); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); + } + + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. + if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. + Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. 
+ CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); +} + +CudnnConvBackendConfig GetDefaultBackendConfig() { + CudnnConvBackendConfig config; + config.set_conv_result_scale(1); + return config; } // Tries to rewrite a single convolution into a call to cudnn. @@ -400,30 +493,28 @@ StatusOr RunOnInstruction(HloInstruction* conv) { bool match; Window window; ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - return CreateCudnnConvBackwardFilter( - conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums); + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums, conv->feature_group_count()); } - std::tie(match, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - // Backward input conv subsumes the conv plus the reverse in operand 1. 
- HloInstruction* reverse = conv->mutable_operand(1); - CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); - HloInstruction* rhs = reverse->mutable_operand(0); - - return CreateCudnnConvBackwardInput( - conv->shape(), conv->mutable_operand(0), rhs, window, dnums); + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), + conv->mutable_operand(0), rhs, window, dnums, + conv->feature_group_count()); } // If all else fails, try a forward convolution. if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), - conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers()); + return CreateCudnnConv( + kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers(), conv->feature_group_count()); } return nullptr; @@ -433,6 +524,12 @@ StatusOr RunOnInstruction(HloInstruction* conv) { return false; } + TF_RETURN_IF_ERROR( + custom_call->set_backend_config(GetDefaultBackendConfig())); + + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << custom_call->ToString(); + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out // the conv result and replace `conv` with it. 
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( @@ -460,7 +557,7 @@ StatusOr RunOnComputation(HloComputation* computation) { } } // namespace -StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { +StatusOr CudnnConvRewriter::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h similarity index 74% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h index fbe7e9849458e9d52be15b3f5610479ab68ffa4c..d8ec72c27bab8912d0dc2aeead114eb010b87b78 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -24,11 +24,9 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. 
-class CudnnConvolutionRewriter : public HloPassInterface { +class CudnnConvRewriter : public HloModulePass { public: - absl::string_view name() const override { - return "cudnn-convolution-rewriter"; - } + absl::string_view name() const override { return "cudnn-conv-rewriter"; } StatusOr Run(HloModule* module) override; }; @@ -36,4 +34,4 @@ class CudnnConvolutionRewriter : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc similarity index 78% rename from tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 65588b6aaf24da628ea586eb52c462b78b8daaa7..543160df8ba477126402c607de2989c04c69725e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,10 +32,13 @@ namespace gpu { namespace { namespace op = xla::testing::opcode_matchers; +using ::testing::_; -class CudnnConvolutionRewriterTest : public HloTestBase { +class CudnnConvRewriterTest : public HloVerifiedTestBase { public: - CudnnConvolutionRewriterTest() { + CudnnConvRewriterTest() + : HloVerifiedTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -82,7 +85,7 @@ class CudnnConvolutionRewriterTest : public HloTestBase { protected: bool RunPass(HloModule* module) { - return CudnnConvolutionRewriter().Run(module).ValueOrDie(); + return CudnnConvRewriter().Run(module).ValueOrDie(); } // A convolution window with stride 1 and zero padding. 
The size fields are @@ -92,7 +95,7 @@ class CudnnConvolutionRewriterTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -104,23 +107,23 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { conv_window.mutable_dimensions(1)->set_size(2); conv_window.mutable_dimensions(1)->set_window_dilation(2); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(CudnnConvolutionRewriterTest, +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -132,25 +135,24 @@ TEST_F(CudnnConvolutionRewriterTest, Window conv_window = default_conv_window_; conv_window.mutable_dimensions(1)->set_size(3); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - 
gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from block35 training. -TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveWithPaddedActivations) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -167,20 +169,20 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from inception v3 training. 
-TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveWithPaddedGradients) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -197,18 +199,19 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -225,18 +228,19 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } 
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -269,18 +273,19 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, conv_window, conv_dnums)); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, + conv_dnums, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, conv_window, conv_dnums) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -300,7 +305,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. 
HloInstruction* output = @@ -316,16 +321,16 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - conv_window, + /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, conv_window, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -334,7 +339,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(CudnnConvolutionRewriterTest, +TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. 
@@ -347,17 +352,18 @@ TEST_F(CudnnConvolutionRewriterTest, 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - default_conv_window_, - tf_default_dnums_for_backward_input_) + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -377,8 +383,7 @@ TEST_F(CudnnConvolutionRewriterTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -399,18 +404,20 @@ TEST_F(CudnnConvolutionRewriterTest, } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -426,7 +433,7 @@ TEST_F(CudnnConvolutionRewriterTest, // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -446,18 +453,20 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -476,9 +485,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { // padding_low=2, padding_high=1, base_dilation=2) // // We should fuse BC even though padding on activations is uneven, because -// PadInsertion will canonicalize the fusion HLO. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveUnevenPaddingOnActivations) { +// CudnnConvPaddingLegalization will canonicalize the fusion HLO. +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. HloInstruction* output = @@ -499,18 +507,20 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -529,9 +539,10 @@ TEST_F(CudnnConvolutionRewriterTest, // BC = BackwardInput(FC) does: // [4] = conv([3], reverse([2]), padding_high=2) // -// We currently don't fuse BC because PadInsertion doesn't support negative -// padding on the gradients of backward convolution (b/32744257). -TEST_F(CudnnConvolutionRewriterTest, +// We currently don't fuse BC because CudnnConvPaddingLegalization +// doesn't support negative padding on the gradients of backward convolution +// (b/32744257). +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -551,23 +562,51 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +// Check that we will materialize a reversed version of a constant in order to +// pattern-match a backwards input convolution. +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { + Array4D constant_arr(4, 4, 2, 2); + constant_arr.FillIota(0); + string constant_str = + LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + ParseAndVerifyModule(absl::StrFormat(R"( + HloModule test + + ENTRY entry_computation { + param0 = f32[128,2,16,16]{3,2,1,0} parameter(0) + constant = f32[4,4,2,2]{3,2,1,0} constant(%s) + ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant), + window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, + dim_labels=bf01_01oi->bf01, feature_group_count=1 + })", + constant_str)); + EXPECT_TRUE(RunPass(&module())); + EXPECT_THAT( + module().entry_computation()->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, + op::Reverse(op::Constant())), + 0)); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b4fdf71623e1597168c6873a0d2b60176e518ce --- /dev/null +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -0,0 +1,419 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::Stream; +using se::dnn::AlgorithmConfig; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::DimIndex; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; +using se::dnn::ProfileResult; + +struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. 
+ // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. + struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional fusion; +}; + +// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, +// returning it (in its entirety) the first time Allocate() is called. 
+class ScratchBufAllocator : public se::ScratchAllocator { + public: + explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) + : scratch_(scratch) {} + + ~ScratchBufAllocator() override = default; + + int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { + return scratch_.size(); + } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override { + if (allocated_) { + return se::port::InternalError( + "Can't allocate twice from a ScratchBufAllocator."); + } + if (byte_size > scratch_.size()) { + return se::port::InternalError(absl::StrCat( + "Can't allocate ", byte_size, + " bytes from a ScratchBufAllocator of size ", scratch_.size())); + } + + allocated_ = true; + return se::DeviceMemory(scratch_); + } + + private: + se::DeviceMemoryBase scratch_; + bool allocated_ = false; +}; + +template +Status RunCudnnConvImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory input_buf(params.input_buf); + DeviceMemory filter_buf(params.filter_buf); + DeviceMemory output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); + VLOG(3) << "tensor_ops_enabled: " + << algorithm.algorithm().tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); + VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); + VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); + VLOG(3) << "Output shape: " << 
ShapeUtil::HumanStringWithLayout(output_shape); + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + CHECK_EQ(primitive_util::NativeToPrimitiveType(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dim.padding_low(), dim.padding_high()); + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + + BatchDescriptor input_descriptor(effective_num_dimensions); + input_descriptor.set_layout(input_dl) + .set_feature_map_count( + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. 
+ input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor filter_descriptor(effective_num_dimensions); + filter_descriptor.set_layout(filter_dl) + .set_input_feature_map_count( + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + convolution_descriptor.set_group_count(feature_group_count); + for (int dim = 0; dim < num_dimensions; ++dim) { + convolution_descriptor + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()); + } + + BatchDescriptor output_descriptor(effective_num_dimensions); + output_descriptor.set_layout(output_dl) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. 
+ if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + convolution_descriptor.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + + switch (kind) { + case CudnnConvKind::kForward: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveWithAlgorithm( + input_descriptor, input_buf, filter_descriptor, filter_buf, + convolution_descriptor, output_descriptor, &output_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardInput: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_buf, output_descriptor, output_buf, + convolution_descriptor, input_descriptor, &input_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardFilter: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_buf, output_descriptor, output_buf, + convolution_descriptor, filter_descriptor, &filter_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kForwardActivation: { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_layout(output_dl); + + se::DeviceMemory side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. 
+ if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. + side_input = output_buf; + } + + stream->ThenFusedConvolveWithAlgorithm( + input_descriptor, input_buf, params.conv_result_scale, + filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), params.fusion->mode, + output_descriptor, &output_buf, scratch_allocator, algorithm, + profile_result); + break; + } + } + + if (!stream->ok()) { + return InternalError( + "Unable to launch convolution with type %s and algorithm (%d, %d)", + CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), + algorithm.algorithm_no_scratch().algo_id()); + } + return Status::OK(); +} + +// Returns the cudnn convolution parameters generated from conv, which must be a +// custom-call to a cudnn convolution. 
+StatusOr GetCudnnConvParams( + const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer) { + CudnnConvParams params; + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + conv->backend_config()); + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv)); + const auto& lhs_shape = conv->operand(0)->shape(); + const auto& rhs_shape = conv->operand(1)->shape(); + const auto& conv_result_shape = conv->shape().tuple_shapes(0); + + params.kind = kind; + params.window = &conv->window(); + params.dnums = &conv->convolution_dimension_numbers(); + params.feature_group_count = conv->feature_group_count(); + params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + params.conv_result_scale = backend_config.conv_result_scale(); + + switch (kind) { + case CudnnConvKind::kForward: + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; + case CudnnConvKind::kBackwardInput: + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + case CudnnConvKind::kForwardActivation: { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = 
*params.fusion; + if (backend_config.activation_mode() < + static_cast(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } + } + } + return params; +} + +} // anonymous namespace + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, + stream, profile_result); +} + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + TF_ASSIGN_OR_RETURN(CudnnConvParams params, + GetCudnnConvParams(conv, operand_buffers, result_buffer)); + + PrimitiveType output_primitive_type = + conv->shape().tuple_shapes(0).element_type(); + switch (output_primitive_type) { + case F16: + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); + case F32: + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); + case F64: + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h 
b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..edbc75a94a1238540390b93f0fa5217852c7781f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This file contains low-level routines for running cudnn convolutions. + +// Calls into cudnn to run the specified convolution. +// +// We provide one overload which takes a scratch buffer, and another which takes +// an allocator which is responsible for allocating the scratch space. In +// theory the second one shouldn't be necessary -- users of this function could +// just ask cudnn how much scratch space it needs for a particular convolution. 
+// But in practice, StreamExecutor does not expose such an API, and in the name +// of parsimony, perhaps it's better not to add it. Instead, the first time you +// call a convolution, you should call the version that takes a scratch +// allocator and take note of how much memory is used. The next time you call +// the same conv, you can provide an explicitly preallocated scratch buffer of +// that size, if you like. +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc deleted file mode 100644 index 68086c86e9ba3860a0c1516c04759754513bfacb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace gpu { -namespace { - -using se::DeviceMemory; -using se::DeviceMemoryBase; -using se::Stream; -using se::dnn::AlgorithmConfig; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::DimIndex; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; -using se::dnn::ProfileResult; - -// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, -// returning it (in its entirety) the first time Allocate() is called. -class ScratchBufAllocator : public se::ScratchAllocator { - public: - explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) - : scratch_(scratch) {} - - ~ScratchBufAllocator() override = default; - - int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { - return scratch_.size(); - } - - se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override { - if (allocated_) { - return se::port::InternalError( - "Can't allocate twice from a ScratchBufAllocator."); - } - if (byte_size > scratch_.size()) { - return se::port::InternalError(absl::StrCat( - "Can't allocate ", byte_size, - " bytes from a ScratchBufAllocator of size ", scratch_.size())); - } - - allocated_ = true; - return se::DeviceMemory(scratch_); - } - - private: - se::DeviceMemoryBase scratch_; - bool allocated_ = false; -}; - -template -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory input_buf, - DeviceMemory 
filter_buf, DeviceMemory output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, - Stream* stream, ProfileResult* profile_result /*= nullptr*/) { - VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); - VLOG(3) << "tensor_ops_enabled: " - << algorithm.algorithm().tensor_ops_enabled(); - VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); - VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; - VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; - VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; - VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; - - const int num_dimensions = window.dimensions_size(); - CHECK_LE(num_dimensions, 3); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). - const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(primitive_util::NativeToPrimitiveType(), - output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - - CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window.dimensions()) { - CHECK_EQ(dim.padding_low(), dim.padding_high()); - } - - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. 
- DataLayout input_dl; - FilterLayout filter_dl; - DataLayout output_dl; - - TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), - XlaConvLayoutsToStreamExecutorLayouts( - dnums, input_shape.layout(), filter_shape.layout(), - output_shape.layout())); - - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(input_dl) - .set_feature_map_count( - input_shape.dimensions(dnums.input_feature_dimension())) - .set_count(input_shape.dimensions(dnums.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape.dimensions(dnums.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(filter_dl) - .set_input_feature_map_count( - filter_shape.dimensions(dnums.kernel_input_feature_dimension())) - .set_output_feature_map_count( - filter_shape.dimensions(dnums.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window.dimensions(dim).padding_low()) - .set_filter_stride( - static_cast(effective_num_dimensions - dim - 1), - window.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(output_dl) - .set_feature_map_count( - output_shape.dimensions(dnums.output_feature_dimension())) - .set_count(output_shape.dimensions(dnums.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) 
{ - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape.dimensions(dnums.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. - if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor.set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - - switch (kind) { - case CudnnConvKind::kForward: - stream->ThenConvolveWithAlgorithm( - input_descriptor, input_buf, filter_descriptor, filter_buf, - convolution_descriptor, output_descriptor, &output_buf, - scratch_allocator, algorithm, profile_result); - break; - case CudnnConvKind::kBackwardInput: - stream->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_buf, output_descriptor, output_buf, - convolution_descriptor, input_descriptor, &input_buf, - scratch_allocator, algorithm, profile_result); - break; - case CudnnConvKind::kBackwardFilter: - stream->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_buf, output_descriptor, output_buf, - convolution_descriptor, filter_descriptor, &filter_buf, - scratch_allocator, algorithm, profile_result); - break; - } - - if (!stream->ok()) { - return InternalError( - "Unable to launch convolution with type %s and algorithm (%lld, %lld)", - CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), - algorithm.algorithm_no_scratch().algo_id()); - } - return Status::OK(); -} - -} // anonymous namespace - -string CudnnConvKindToString(CudnnConvKind kind) { - switch (kind) { - case CudnnConvKind::kForward: - return "forward"; - case CudnnConvKind::kBackwardFilter: - return "backward_filter"; - case CudnnConvKind::kBackwardInput: - return "backward_input"; - } -} - -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - 
const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, algorithm, - stream, profile_result); -} - -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); - switch (output_primitive_type) { - case F16: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); - case F32: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); - case F64: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); - default: - LOG(FATAL) << ShapeUtil::HumanString(output_shape); - } -} - -} // namespace gpu -} // 
namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h deleted file mode 100644 index 944e4ac686d45408b08ff1faa321510c1c8920ba..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ - -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace xla { -namespace gpu { - -// This file contains low-level routines for running cudnn convolutions. - -// Different types of convolutions supported by cudnn. -// -// A way to think about these is that a convolution is defined by three arrays -// -- the "input", the "filter", and the "output" -- and given any two of these, -// we can compute the third. 
For example, a backward-input convolution takes as -// input a filter and an "output" and produces an "input" such that if one were -// to do a forward convolution of "input" using filter, the result would be -// something with the same shape as "output". -// -// This way of thinking is not correct if you look at the values produced. For -// example, a backward-input convolution is not actually the mathematical -// inverse of a forward convolution. But it's right as far as the shapes and -// "connectivity" (i.e. which elements of the input affect which elements of -// the output) are concerned. -enum class CudnnConvKind { - kForward, // input + filter => output - kBackwardInput, // filter + output => input - kBackwardFilter, // input + output => filter -}; - -// Converts a CudnnConvKind value to a string. -string CudnnConvKindToString(CudnnConvKind kind); - -// Calls into cudnn to run the specified convolution. -// -// Note that depending on the value of CudnnConvKind, the result of this call -// may be written into input_buf, filter_buf, or output_buf! -// -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. -// -// We provide one overload which takes a scratch buffer, and another which takes -// an allocator which is responsible for allocating the scratch space. In -// theory the second one shouldn't be necessary -- users of this function could -// just ask cudnn how much scratch space it needs for a particular convolution. -// But in practice, StreamExecutor does not expose such an API, and in the name -// of parsimony, perhaps it's better not to add it. Instead, the first time you -// call a convolution, you should call the version that takes a scratch -// allocator and take note of how much memory is used. 
The next time you call -// the same conv, you can provide an explicitly preallocated scratch buffer of -// that size, if you like. -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); - -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ac11bcf657db4eab76c611b8975e12e190994c5 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -0,0 +1,279 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { +namespace { + +// Describes a matched pattern: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// Where side_input has the shape of output buffer, and bias is a 1D array with +// the dimension of number of output features. +struct ConvWithRelu { + HloInstruction* maximum; + HloCustomCallInstruction* conv; + HloInstruction* bias; + HloInstruction* side_input; + HloConstantInstruction* alpha_conv; + HloConstantInstruction* alpha_side_input; +}; + +absl::optional FindConvWithRelu(HloInstruction* instr) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Broadcast; + using match::Constant; + using match::GetTupleElement; + using match::Maximum; + using match::MultiplyAnyOrder; + using match::Op; + + // The pattern we want to match: + // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); + // + // With its variants involving commute/reassociation of adds, multiplies, and + // max, and omission of alpha1, side_input, alpha2, or bias. 
+ + HloInstruction* relu_input; + + // Match max(0, relu_input). + auto zero_pattern = Broadcast(match::ConstantScalar(0)); + if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && + !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { + return absl::nullopt; + } + HloInstruction* conv_instr = nullptr; + HloInstruction* alpha_conv_instr = nullptr; + HloInstruction* alpha_side_input_instr = nullptr; + HloInstruction* bias_broadcast_instr = nullptr; + HloInstruction* bias = nullptr; + HloInstruction* side_input = nullptr; + + // These nodes will not be in the returned value, but we need to check them + // for single use. + HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, + *mul1 = nullptr, *mul2 = nullptr; + + const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); + const auto conv_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto conv_pattern = GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); + return AnyOf( + MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); + }(); + const auto side_input_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + // If bias is already matched, match arbitrary additional input as side + // input. Note this may force a cheap operation (e.g. broadcast) to be + // materialized into a large buffer, as large as the output buffer. + // + // TODO(timshen): If in practice there are significant false positives, we + // should fix it. + auto side_input_pattern = Op(&side_input); + return AnyOf( + MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), + side_input_pattern); + }(); + + { + // Try to match any of the following form of add, in any association: + // addends[0] + // addends[0] + addends[1] + // addends[0] + addends[1] + addends[2] + // + // Then try to match each addend with one of the three patterns: bias, conv, + // or side_input. 
Notice that side_input matching must go last, as it + // also matches a conv or a bias. + HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; + auto add3_pattern = [&] { + auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); + return AnyOf( + AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, + Op(&addends[0])); + }(); + CHECK(Match(relu_input, add3_pattern)); + for (auto addend : addends) { + if (addend) { + if (bias == nullptr && Match(addend, bias_pattern)) { + CHECK(bias); + } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { + CHECK(conv_instr); + } else if (side_input == nullptr && Match(addend, side_input_pattern)) { + CHECK(side_input); + } else { + return absl::nullopt; + } + } + } + } + + if (conv_instr == nullptr) { + return absl::nullopt; + } + + for (HloInstruction* instr : + {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { + if (instr && instr->user_count() > 1) { + return absl::nullopt; + } + } + + auto conv = Cast(conv_instr); + auto bias_broadcast = + CastOrNull(bias_broadcast_instr); + + if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { + return absl::nullopt; + } + + if (bias_broadcast) { + // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. 
+ if (bias_broadcast_instr->dimensions().size() != 1) { + return absl::nullopt; + } + if (bias_broadcast_instr->dimensions(0) != + conv->convolution_dimension_numbers().output_feature_dimension()) { + return absl::nullopt; + } + } + + return ConvWithRelu{ + instr, + conv, + bias, + side_input, + CastOrNull(alpha_conv_instr), + CastOrNull(alpha_side_input_instr)}; +} + +StatusOr> TryRewriteToCudnnForwardRelu( + ConvWithRelu match) { + auto conv = match.conv; + + HloComputation* computation = conv->parent(); + PrimitiveType element_type = conv->operand(0)->shape().element_type(); + + const auto get_alpha_value = + [](HloConstantInstruction* instr) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto alpha, + Cast(instr)->literal().Convert(F64)); + return alpha.GetFirstElement(); + }; + + double alpha_conv = 1; + if (match.alpha_conv) { + TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); + } + + double alpha_side_input; + if (match.side_input) { + if (match.alpha_side_input) { + TF_ASSIGN_OR_RETURN(alpha_side_input, + get_alpha_value(match.alpha_side_input)); + } else { + alpha_side_input = 1; + } + } else { + CHECK(match.alpha_side_input == nullptr); + alpha_side_input = 0; + } + + auto bias = match.bias; + if (!bias) { + auto zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + + int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( + conv->convolution_dimension_numbers().output_feature_dimension()); + bias = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout(element_type, + {num_output_feature}), + zero, {})); + } + + CHECK(bias); + std::vector args = {conv->mutable_operand(0), + conv->mutable_operand(1), bias}; + if (match.side_input) { + args.push_back(match.side_input); + } + auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( + conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + 
new_conv->set_window(conv->window()); + new_conv->set_convolution_dimension_numbers( + conv->convolution_dimension_numbers()); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + config.set_activation_mode( + static_cast(se::dnn::ActivationMode::kRelu)); + config.set_conv_result_scale(alpha_conv); + config.set_side_input_scale(alpha_side_input); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << new_conv->ToString(); + return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), + new_conv, 0); +} + +} // namespace + +StatusOr CudnnFusedConvRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector matches; + int num_forward_convs = 0; + for (auto instr : computation->instructions()) { + auto match = FindConvWithRelu(instr); + if (match.has_value()) { + matches.push_back(*match); + } + if (auto call = DynCast(instr)) { + if (call->custom_call_target() == kCudnnConvForwardCallTarget) { + num_forward_convs++; + } + } + } + VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() + << " out of " << num_forward_convs << " forward convs."; + std::vector>> + replacements; + for (const ConvWithRelu& match : matches) { + TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); + replacements.push_back({match.maximum, std::move(new_instr)}); + changed = true; + } + for (auto& replacement : replacements) { + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + replacement.first, std::move(replacement.second))); + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h similarity index 57% rename from tensorflow/compiler/xla/service/gpu/gpu_options.h rename to 
tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h index 498d4a94955cb2c50e0b165f28ded44ac1c0bfff..613ed8dbdc33dfc3684deb5fd3ee8f5b9ea5fc50 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -13,21 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ -#include "tensorflow/compiler/xla/service/hlo_module_config.h" - -// Helper functions for querying options that are specific to the GPU backend. +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { -// Returns true if we should use heuristics to assign convolution layouts, as -// opposed to always assigning NCHW. 
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config); +class CudnnFusedConvRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-fused-convolution-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2460d951bd7c5aa50b4d79791effa567a9103fcd..6dcdaf1cfe06e446deed847aaf29088a7ed10e13 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter( compute_nested_(std::move(compute_nested)) {} StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // The libdevice math functions differentiate between "double" and "float" by // appending an 'f' to the function's name. 
libdevice doesn't have f16 math // functions, so we convert the operands to f32 before calling the function @@ -94,7 +92,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { converted_operands[i] = - b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); + FPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -107,22 +105,20 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( break; default: return Unimplemented("Bad type for libdevice math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } llvm::Value* result = EmitMathCall(munged_callee, converted_operands, converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // llvm intrinsics differentiate between half/float/double functions via // the suffixes ".f16", ".f32" and ".f64". 
string munged_callee = callee_name; @@ -138,22 +134,20 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( break; default: return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } return EmitMathCall(munged_callee, operands, input_types, output_type); } StatusOr GpuElementalIrEmitter::EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", - PrimitiveType_Name(input_type).c_str(), - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(input_type), + PrimitiveType_Name(output_type)); } } @@ -163,8 +157,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( } StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); @@ -183,8 +176,7 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( } StatusOr GpuElementalIrEmitter::EmitPowerOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { CHECK_EQ(op->opcode(), HloOpcode::kPower); PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); @@ -218,7 +210,7 @@ StatusOr 
GpuElementalIrEmitter::EmitPowerOp( // TODO(jlebar): Does this happen with fastmath disabled? If not, should // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -227,55 +219,56 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( } StatusOr GpuElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { + PrimitiveType prim_type, llvm::Value* value) { return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog1p( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitSin( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitCos( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExp( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } 
-StatusOr GpuElementalIrEmitter::EmitExpm1( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); } StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { // Emit a fast approximation of tanh instead of calling __nv_tanh. // __nv_tanh is particularly bad because it contains branches, thus // preventing LLVM's load-store vectorizer from working its magic across a @@ -285,17 +278,15 @@ StatusOr GpuElementalIrEmitter::EmitTanh( // Upcast F16 to F32 if necessary. llvm::Type* type = prim_type == F16 ? 
b_->getFloatTy() : value->getType(); - llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* input = FPCast(value, type); llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, value->getType()); + return FPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type, + absl::Span attributes) { std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( @@ -315,29 +306,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( callee->addFnAttr(attribute); } - return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return Call(callee, llvm_ir::AsArrayRef(operands)); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); +llvm::Value* GpuElementalIrEmitter::EmitThreadId() { + llvm::Value* block_id = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); 
+ llvm::Value* thread_id_in_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kMap: return [=, &operand_to_generator]( @@ -368,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. 
" - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -383,7 +366,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(init_value, accum_ptr); + Store(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -405,22 +388,36 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = b_->CreateNSWMul( + llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = b_->CreateNSWSub( - b_->CreateNSWAdd(stridden_index, window_index[i]), + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), index_typed_const(window.dimensions(i).padding_low())); + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. + input_index[i] = + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); + // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. 
- in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpULT( - input_index[i], - index_typed_const(operand->shape().dimensions(i)))); + in_bounds = + And(in_bounds, + ICmpULT(input_index[i], + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -432,12 +429,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b_->CreateLoad(accum_ptr), input_value})); - b_->CreateStore(accum_value, accum_ptr); + compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); + Store(accum_value, accum_ptr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return b_->CreateLoad(accum_ptr); + return Load(accum_ptr); }; case HloOpcode::kReduce: // TODO(b/112040122): This should be supported. diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 84454d31bb820a3de6ef3364bd205b8115bd95c0..e8b56a39ce58b6aab35c1c977553c7ff7e753273 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace gpu { @@ -38,9 +38,9 @@ namespace gpu { class GpuElementalIrEmitter : public ElementalIrEmitter { public: // A NestedComputer computes an element of the output of the given computation - // given an ArraySlice of its input elements. 
+ // given a Span of its input elements. using NestedComputer = std::function( - const HloComputation&, tensorflow::gtl::ArraySlice)>; + const HloComputation&, absl::Span)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, @@ -48,85 +48,77 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: - StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; + StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value) override; StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; - llvm::Value* 
EmitThreadId() const override; + llvm::Value* EmitThreadId() override; private: // Emits IR for op, which must have opcode kPower. StatusOr EmitPowerOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Emits IR to call a device function named "callee_name" on the given // operand. Returns the IR value that represents the return value. llvm::Value* EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_type, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const; + const string& callee_name, absl::Span operands, + absl::Span input_type, PrimitiveType output_type, + absl::Span attributes); // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a libdevice function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. 
StatusOr EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index def595d217a831b3136adbb77ff6d2897e09efd9..ca4a605af5d3b6b58b603d7ddad60ed9ae8a212f 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -43,8 +43,8 @@ StatusOr> FftScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) { } // namespace -FftThunk::FftThunk(FftType fft_type, - tensorflow::gtl::ArraySlice fft_length, +FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, @@ -213,7 +212,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, return 
Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, - FftTypeToString(fft_type_).c_str()); + FftTypeToString(fft_type_)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 4adec7ee54459abbbc4235550689c3cb1f7858a6..2be50e08bd2b561b44245b20e1fb200e31e65a41 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -62,7 +62,7 @@ class FftThunk : public Thunk { public: // Constructs a thunk for launching an FFT on a stream. // Semantics of null hlo_instruction argument are as in Thunk. - FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice fft_length, + FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 1bd88233e183af89268865e2a80155b2d7f638b6..30c1f9088968305ad0207164ecb07ba13cc89ee6 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -225,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. 
- if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) { + if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - user->fusion_kind() == HloInstruction::FusionKind::kInput); + (user->fusion_kind() == HloInstruction::FusionKind::kInput && + LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() << ": Some of its users are not loop/input fusion kernels."; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 7e3f5775b8d97f43a0bba201d24f34c2d337fabb..f19996edfe3dd923aa686a19621ce28a4aed5a45 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -32,7 +32,7 @@ namespace gpu { // 2) The result of merging the fusion instruction into its users would not // increase bytes transferred. // -class FusionMerger : public HloPassInterface { +class FusionMerger : public HloModulePass { public: absl::string_view name() const override { return "fusion merger"; } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index b22bb1d39ba177ef42673c7a3755694b43c15d14..7cc869ed9e89688d6ea06428a7bade3ebe55ea23 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { op::Fusion(op::Parameter())); } +TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { + auto module = ParseHloString(R"( + HloModule m + + f1_computation { + f1_p0 = f32[16,16,256]{0,1,2} parameter(0) + add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0) + // Note that the copy changes the layout from {0,1,2} to {2,1,0}. 
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add) + } + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + f2_computation { + f2_p0 = f32[16,16,256]{2,1,0} parameter(0) + f2_zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[16,16,256]{0,1,2} parameter(0) + f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation + ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + })") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 2c02ec2584f1e04d5f98f14a4f926f34fc80932b..9c4a4903667ea1a6c99ce9e912c9d0497b8e389f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -186,7 +186,7 @@ StatusOr DoGemmAutotune( } return InternalError( - "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms " "ran successfully", stream, algorithms.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 75f414e47fe3edcc1b10b392ed5cc5038be6c190..e2ab00ce41c9e23e91449f249620d61d0f7736ae 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -27,22 +28,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace gpu { -StatusOr GpuCopyInsertion::FindOrInsertCopy( - HloInstruction* hlo) { - HloInstruction*& copy = hlo_to_copy_map_[hlo]; - if (copy == nullptr) { - TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); - } - return copy; -} - StatusOr GpuCopyInsertion::Run(HloModule* module) { CopyInsertion generic_copy_insertion; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 8ffae18fe820aa01701731ee56a83aeacf0eab0d..4c7e38ffeb60f87a4f27e212572ae31cca8e0947 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -25,20 +25,11 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public HloPassInterface { +class GpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; - - protected: - // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted to materialize operands of library - // calls. 
The key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap hlo_to_copy_map_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 88be63e2679dcb145a1d7c1d3e18206c9e62a9c3..57426327822d95a42f407ed7488f35acfd3623d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" @@ -160,7 +161,7 @@ Status GpuExecutable::ExecuteThunks( if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", - main_stream, block_status.error_message().c_str()); + main_stream, block_status.error_message()); } } @@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { } module_spec.AddCudaPtxInMemory(ptx().c_str()); - tensorflow::gtl::FlatMap globals; + absl::flat_hash_map globals; se::ModuleHandle module_handle; executor->LoadModule(module_spec, &module_handle); @@ -234,7 +235,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { DeviceMemoryAllocator* memory_allocator = run_options->allocator(); @@ -260,10 +261,9 @@ StatusOr GpuExecutable::ExecuteOnStream( if (buffer.is_null() && buffer.size() > 0) { return FailedPrecondition( "Cannot run XLA computation because pointer to (sub-)buffer at " - "index %s of parameter %lld was null. 
All pointers to " - "(sub-)buffers must not be null, unless the (sub-)buffer has zero " - "elements.", - allocation.param_shape_index().ToString().c_str(), param_no); + "index %s of parameter %d was null. All pointers to (sub-)buffers " + "must not be null, unless the (sub-)buffer has zero elements.", + allocation.param_shape_index().ToString(), param_no); } buffer_allocations_builder.RegisterBuffer(i, buffer); @@ -326,7 +326,7 @@ StatusOr GpuExecutable::ExecuteOnStream( StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 627a05e2401e9f07f764988637e87773780ab1f2..0e276282e40fba0ae4881a51dad0c7c9e8d1c081 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -34,8 +36,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -78,12 +78,12 @@ class GpuExecutable : public Executable { // match the compute capability passed to this object's constructor. StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; private: // If `block_host_until_done` is false, execution will not block the host @@ -101,7 +101,7 @@ class GpuExecutable : public Executable { const PointsToSet& GetRootPointsToSet() const; using BufferAllocToDeviceMemoryMap = - tensorflow::gtl::FlatMap; + absl::flat_hash_map; // Loads the PTX or CUBIN for this executable into `executor` and resolves the // globals corresponding to constant buffers. Returns a map mapping buffer diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d31fd5570c468b0c42fa308535fd335f3588a79 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { + +namespace { +void AppendParams(const HloInstruction& instr, + std::vector* params) { + if (instr.opcode() == HloOpcode::kFusion) { + params->insert(std::end(*params), std::begin(instr.fused_parameters()), + std::end(instr.fused_parameters())); + } else { + for (HloInstruction* operand : instr.operands()) { + params->push_back(operand); + } + } +} +} // namespace + +bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, + const HloInstruction& reduce) { + std::vector params; + AppendParams(producer, ¶ms); + AppendParams(reduce, ¶ms); + int64 max_rank = -1; + const Layout* max_rank_layout; + for (HloInstruction* param : params) { + if (ShapeUtil::IsArray(param->shape()) && + ShapeUtil::Rank(param->shape()) > max_rank) { + max_rank = ShapeUtil::Rank(param->shape()); + max_rank_layout = ¶m->shape().layout(); + } + } + return absl::c_all_of(params, [&](HloInstruction* param) { + return (!ShapeUtil::IsArray(param->shape())) || + (ShapeUtil::Rank(param->shape()) < max_rank) || + (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); + }); +} + +bool IsInputFusibleReduction(const HloInstruction& instr) { + if (instr.IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr.fused_expression_root()->operands()) { + if (IsReductionToVector(*operand)) { + CHECK(instr.fusion_kind() == 
HloInstruction::FusionKind::kInput) + << " Multi-output fusion rooted at reduction-to-vector ops must be " + "of kind kInput: " + << instr.ToString(); + return true; + } + } + return false; + } else if (instr.opcode() == HloOpcode::kFusion) { + if (IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; + } + return false; + } + return IsReductionToVector(instr); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h new file mode 100644 index 0000000000000000000000000000000000000000..f7c24a0d5bbfcc61389ea19ae7f769671e4e974d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from +// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion. 
+ +namespace xla { +namespace gpu { + +// The code emitted for reduce-rooted input fusions (EmitReductionToVector) +// suffers from poor data locality if the layouts of input parameters differ. In +// such situations it is better not to fuse. Only input params with +// maximum rank are considered. Params with smaller ranks will be broadcasted +// and have not been observed to cause data locality issues. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. +bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, + const HloInstruction& reduce); + +// Whether `instr` is fusible as root of a reduce input fusion, i.e. `instr` +// is either an unfused reduction-to-vector op, an input fusion rooted at a +// reduction-to-vector op, or a multi-output input fusion with at least one +// reduction-to-vector op root. +// Note that reduction ops are lowered in different ways. Reduce input fusions +// are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at +// reduction-to-vector ops. Other reduction ops are lowered by +// GpuElementalIrEmitter and fused like elementwise ops. +bool IsInputFusibleReduction(const HloInstruction& instr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91b7bc61fda5a07c163a07ec0e1644d2ad9db49 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +using GpuFusibleTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + scalar_add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + })"; + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_ElementwiseProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + ROOT reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(exp->opcode(), HloOpcode::kExp); + EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*exp, *reduce)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_MixedLayoutProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + mixed_input_layouts_computation { + p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + 
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation + reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion) + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(1); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect); + EXPECT_FALSE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + +TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce { + p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0) + c0.1 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(p0.1, c0.1), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + copy = f32[128,1024,32,32]{1,3,2,0} copy(p0) + ROOT reduce_fusion = 
f32[1024]{0} fusion(copy), kind=kInput, calls=fused_reduce + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->fused_expression_root()->opcode(), HloOpcode::kReduce); + const HloInstruction* copy = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(copy->opcode(), HloOpcode::kCopy); + EXPECT_FALSE(LayoutsAreReduceInputFusionFriendly(*copy, *reduce)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_LayoutChangingFusionProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + layout_changing_computation { + p0.1 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast) + select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=layout_changing_computation + ROOT reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(0); + 
ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kCopy); + EXPECT_FALSE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + broadcasting_computation { + p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f32[128]{0} parameter(1) + broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0} + ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast) + } + ENTRY entry { + p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1 = f16[128]{0} parameter(1) + loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation + c0.2 = f32[] constant(0) + ROOT reduce = f32[128,1024]{0,1} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kAdd); + EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + // Reduction-to-vector lowered by IrEmitterUnnested. 
+ ROOT reduce = f32[512]{0} reduce(p1, c0), dimensions={0,2,3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + c0 = f32[] parameter(0) + p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) + // Reduction lowered by GpuElementalIrEmitter. + ROOT reduce = f32[8,512,5,1,1]{4,3,2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + ROOT reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = f32[128,512]{1,0} fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) + ROOT reduce = f32[8,5,1,1]{3,2,1,0} reduce(p1, c0), dimensions={1,3}, to_apply=scalar_add + } + ENTRY entry 
{ + p0 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(0) + ROOT fusion = f32[8,5,1,1]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce.0 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + reduce.1 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + ROOT root = (f32[128,512]{1,0}, f32[128,512]{1,0}) tuple(reduce.0, reduce.1) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[128,512]{1,0}, f32[128,512]{1,0}) fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, + IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1) + ROOT root = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + 
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce.0 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + reduce.1 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + ROOT root = (f32[512,28]{1,0}, f32[512,28]{1,0}) tuple(reduce.0, reduce.1) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[512,28]{1,0}, f32[512,28]{1,0}) fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, + IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1) + ROOT root = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc 
similarity index 94% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule.cc rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 76055ff009c05499ecfbfce31d87c65f3e39785d..02a0d028c118aba23996f9b97d05443bb4a00c88 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -17,12 +17,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -184,13 +185,13 @@ void BFSLaunchOrder(const HloComputation* computation, } // end namespace -HloSchedule::HloSchedule() {} +GpuHloSchedule::GpuHloSchedule() {} /* static */ -StatusOr> HloSchedule::Build( +StatusOr> GpuHloSchedule::Build( const HloModule& module, const StreamAssignment& stream_assignment, int64 pointer_size) { - std::unique_ptr schedule(new HloSchedule); + std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. const HloComputation* entry_computation = module.entry_computation(); @@ -198,11 +199,12 @@ StatusOr> HloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. 
TF_ASSIGN_OR_RETURN( - schedule->thunk_launch_order_, - ScheduleOneComputation( + HloInstructionSequence sequence, + ScheduleComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); + schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h similarity index 78% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule.h rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 1ce7a48ac8fcbbad0b3697845681582fe806b322..07a7fc67aa555845c3de57e574ab582403ec0490 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ #include #include @@ -33,12 +33,14 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. -class HloSchedule { +// launch order. This class differs from xla::HloSchedule in that HloSchedule +// represents a total order of all instructions in the module for backends which +// execute HLO instructions strictly sequentially. 
+class GpuHloSchedule { public: - // Constructs an HloSchedule for the given module, based on the given stream - // assignment. - static StatusOr> Build( + // Constructs a GpuHloSchedule for the given module, based on the given + // stream assignment. + static StatusOr> Build( const HloModule& module, const StreamAssignment& stream_assignment, int64 pointer_size); @@ -56,7 +58,7 @@ class HloSchedule { } private: - HloSchedule(); + GpuHloSchedule(); std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; @@ -65,4 +67,4 @@ class HloSchedule { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc similarity index 86% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index d4a96cd5b353436ea4d1d6db3810b3e777449cd4..b857fa775a76ec999b505a2a64332cc0c54cf00b 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include #include @@ -24,22 +24,23 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class HloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloVerifiedTestBase { protected: using HloVec = std::vector; // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); - static std::unique_ptr BuildHloSchedule( + static std::unique_ptr BuildGpuHloSchedule( const HloModule& module, const StreamAssignment& streams) { - return HloSchedule::Build(module, streams, /*pointer_size=*/8) + return GpuHloSchedule::Build(module, streams, /*pointer_size=*/8) .ConsumeValueOrDie(); } @@ -65,7 +66,7 @@ class HloScheduleTest : public HloTestBase { // Test of a single stream, where data dependencies fully determine the // execution order. 
-TEST_F(HloScheduleTest, SequentialMatMul) { +TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -73,10 +74,10 @@ TEST_F(HloScheduleTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -85,7 +86,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) { EXPECT_EQ(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), HloVec({dot1, dot2})); @@ -123,7 +124,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) { // Test of a single stream, where data dependencies do not fully determine the // execution order, but the stream assignment does. 
-TEST_F(HloScheduleTest, SequentialAdd) { +TEST_F(GpuHloScheduleTest, SequentialAdd) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -147,7 +148,7 @@ TEST_F(HloScheduleTest, SequentialAdd) { EXPECT_EQ(streams->StreamNumberForHlo(*add1), streams->StreamNumberForHlo(*add3)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), HloVec({add1, add2, add3})); @@ -195,18 +196,18 @@ TEST_F(HloScheduleTest, SequentialAdd) { } // Test of two streams. -TEST_F(HloScheduleTest, ConcurrentMatMul) { +TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* add = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -215,7 +216,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { EXPECT_NE(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = 
BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y}); EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) || @@ -251,7 +252,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { } // Test of multiple streams. -TEST_F(HloScheduleTest, LatticeMatMul) { +TEST_F(GpuHloScheduleTest, LatticeMatMul) { // d00 -- layer 0 // / \ // d10 d11 -- layer 1 @@ -266,26 +267,26 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + 
builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); @@ -307,7 +308,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // We don't check the thunk launch order, since there are many valid total // orders, and it's annoying to express. - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); auto order = schedule->ConsumeHloOrdering(); const HloVec all_params( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4944c41f7d8dc7a78a3cd094aee4d7087c74857e..4268fb2c7a813b3b53e4cd48746028a7b369f28e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr GpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "GPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index 
bbb3340760c8330bd6570f33382f004315c6d0bd..9c64b4d10c9d1b172f7bd89b5fdacda893488bf8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // his pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the GPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class GpuHloSupportChecker : public HloPassInterface { +class GpuHloSupportChecker : public HloModulePass { public: GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 0a4089df4c954cafcbe241189ee79a0995683513..7d01eeb02567d710e9de089c7f29ffcc5f959f9a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloTestBase { +class GpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -57,7 +57,10 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { HloInstruction::CreateParameter(1, sparse_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( sparse_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + // Since verifier is reporting sparse layouts as errors, we should + // use a regular HloModule instead of VerifiedHloModule to avoid + // verifier errors being triggered in the destructor. + auto module = HloTestBase::CreateNewModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index d033faee8d25ed81a1483f8314652ef999ab36c5..1c0a23fa3eb38961d420aff05e412c3b4d8524e7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,11 +18,12 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -90,45 +91,46 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // operands and the output shape. Depending on the underlying algorithm, one of // { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints) { - CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); - Shape input_shape; - Shape filter_shape; - Shape output_shape; - const auto& target = instr->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->shape().tuple_shapes(0); - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_shape = instr->shape().tuple_shapes(0); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->operand(0)->shape(); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->shape().tuple_shapes(0); - output_shape = instr->operand(1)->shape(); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + HloCustomCallInstruction* instr, LayoutConstraints* constraints) { + Shape lhs_shape = instr->operand(0)->shape(); + Shape 
rhs_shape = instr->operand(1)->shape(); + Shape result_shape = instr->shape().tuple_shapes(0); + + Shape* input_shape; + Shape* filter_shape; + Shape* output_shape; + + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + input_shape = &lhs_shape; + filter_shape = &rhs_shape; + output_shape = &result_shape; + break; + case CudnnConvKind::kBackwardInput: + input_shape = &result_shape; + filter_shape = &rhs_shape; + output_shape = &lhs_shape; + break; + case CudnnConvKind::kBackwardFilter: + input_shape = &lhs_shape; + filter_shape = &result_shape; + output_shape = &rhs_shape; + break; } { DataLayout input; FilterLayout filter; DataLayout output; - if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { - std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); - } else { - input = DataLayout::kBatchDepthYX; - filter = FilterLayout::kOutputInputYX; - output = DataLayout::kBatchDepthYX; - } + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); TF_ASSIGN_OR_RETURN( - std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), - *output_shape.mutable_layout()), + std::tie(*input_shape->mutable_layout(), + *filter_shape->mutable_layout(), + *output_shape->mutable_layout()), StreamExecutorConvLayoutsToXlaLayouts( instr->convolution_dimension_numbers(), input, filter, output)); } @@ -141,24 +143,23 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( instr, /*index=*/{0})); // Set layouts of the instructions' shapes. 
- if (target == kCudnnConvForwardCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardInputCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(result_shape.layout(), *call_result_buf)); + // instr->operand(2), if exists, is the bias buffer. There is no need to + // assign layout to it, as it has only one dimension. + + // instr->opernad(3), if exists, is the side input buffer. + if (instr->operand_count() == 4) { + if (kind != CudnnConvKind::kForwardActivation) { + return InternalError( + "Invalid convolution. Conv has a side input, but kind is not fused " + "conv forward: %s", + instr->ToString()); + } + // The side input layout must match the output layout. 
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3)); } return Status::OK(); } @@ -173,8 +174,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( ++iterator) { HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { - TF_RETURN_IF_ERROR( - AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); + TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( + Cast(instruction), constraints)); } // For batched dot we require the default layout. @@ -207,21 +208,37 @@ Status GpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(op1_shape, instruction, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); + } else if (instruction->opcode() == HloOpcode::kSort && + ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + // Make sure that all the operands and the output(s) have the same layout. + Shape keys_shape = instruction->operand(0)->shape(); + Layout keys_layout = + LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + Shape shape = instruction->operand(i)->shape(); + *shape.mutable_layout() = keys_layout; + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(shape, instruction, i)); + const LogicalBuffer* output_buffer; + if (ShapeUtil::IsArray(instruction->shape())) { + TF_ASSIGN_OR_RETURN( + output_buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {})); + } else { + TF_ASSIGN_OR_RETURN( + output_buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {i})); + } + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(keys_layout, *output_buffer)); + } } } return Status::OK(); } -bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) { - // - Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. 
- // - Inputs to cudnn convolution require custom layouts handled in - // AddBackendConstraints. - return !IsCustomCallToDnnBatchNorm(*instruction) && - !IsCustomCallToDnnConvolution(*instruction); -} - Status GpuLayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f..6a48e55fd2e784f80a50f4565107db177fb43bfc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -29,8 +30,11 @@ namespace gpu { class GpuLayoutAssignment : public LayoutAssignment { public: explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, se::StreamExecutor* stream_executor) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} @@ -42,12 +46,10 @@ class GpuLayoutAssignment : public LayoutAssignment { Status PropagateBufferConstraint( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) override; - bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) override; private: Status AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* 
constraints); + HloCustomCallInstruction* instr, LayoutConstraints* constraints); se::StreamExecutor* stream_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index fbc8ddf599570b90e93eb463a1fd6c275b73711c..4822b820f4e229336e2b26cfbd0097c8c31a50c8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, 
LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the @@ -347,9 +351,11 @@ TEST_F(LayoutAssignmentTest, DotLayout) { ParseHloString(hlo_text)); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); Shape expected_shape = @@ -359,6 +365,34 @@ TEST_F(LayoutAssignmentTest, DotLayout) { op::ShapeWithLayout(expected_shape))); } +TEST_F(LayoutAssignmentTest, SortLayout) { + const char* hlo_text = R"( + HloModule SortLayout + ENTRY sort { + keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + values = f32[2,3]{1,0} parameter(0) + transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} + ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), + dimensions={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {3, 2}, {1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sort(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // 
namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 44303724bb5cda4f392c8d17d60c114286b6b7e2..f3c274429242d5c989146d14ea523b5910408cff 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -84,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } infeed_manager->EnqueueDestination(std::move(buffers)); @@ -97,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size == 0) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 0e205b9c028dee91b422bd9f18a1c128d54e15f8..51627402b45f594dab3480129ba182d54d01b811 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -35,8 +35,8 @@ using absl::StrAppend; using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos) { + absl::Span io_hlos, + absl::Span non_io_hlos) { // I/O HLOs are bound to the arguments of the current IR function. 
I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index eee40b0e91fc03013a6978ae3cfe42b87633eed7..c0edae530cedba45c897b07b7b9cc72eaaab397c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace gpu { @@ -45,8 +45,8 @@ class HloToIrBindings { alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} void EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos); + absl::Span io_hlos, + absl::Span non_io_hlos); // Rebinds the given HLO to the LLVM IR value that represent its address. 
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index fee6d2af3bfd4976f5845edf592e8310b55a3feb..8c3a026740851767855beae59d6a3c92f7a0d6bd 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Infeeding to GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 0f2c83aeb2633a007559d8caac78ea2d233539ed..1d66787d8927ad818cbc66d19429c1816fc51748 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -26,7 +28,7 @@ namespace gpu { namespace { -bool IsFusile(const HloInstruction& hlo) { +bool IsFusible(const HloInstruction& hlo) { // Don't fuse get-tuple-element on GPU: We can, but it's slower than not // fusing. We never generate kernels for unfused GTEs. 
Instead, if an // unfused GTE is an input to a kernel (including a fusion kernel), we @@ -41,10 +43,11 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; } @@ -124,8 +127,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } // Compute the precise number of operands to the new fusion. - tensorflow::gtl::FlatSet operands( - a->operands().begin(), a->operands().end()); + absl::flat_hash_set operands(a->operands().begin(), + a->operands().end()); operands.insert(b->operands().begin(), b->operands().end()); // If there's an edge between `a` and `b`, don't count it: We're fusing that // producer -> consumer relationship. @@ -221,6 +224,18 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Scatter is only supported at the root of a kInput fusion. + if (producer->opcode() == HloOpcode::kScatter) { + return false; + } + + // Do not fuse into reduce input fusions if the resulting kernel would suffer + // from poor data locality (due to unfriendly input layouts). + if (IsInputFusibleReduction(*consumer) && + !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { + return false; + } + // We can't fuse library calls, so if a user of such an op could become a // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for // further rationale. 
@@ -245,7 +260,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - if (!IsFusile(*producer) || !IsFusile(*consumer) || + if (!IsFusible(*producer) || !IsFusible(*consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } @@ -276,7 +291,8 @@ bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { - if (IsReductionToVector(*consumer)) { + if (IsReductionToVector(*consumer) || + consumer->opcode() == HloOpcode::kScatter) { return HloInstruction::FusionKind::kInput; } if (producer->opcode() == HloOpcode::kDot || diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8..fd9b7cee80bdad9a8ed625872ae68bede10200b3 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -171,6 +172,78 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} 
copy(p0) + constant.1 = f32[] constant(0) + ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + fused_reduce { + p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0) + mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1) + c0.1 = f32[] constant(0) + ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce + ROOT root = (f32[]) tuple(fusion) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), op::Add(op::Copy(), op::Copy())); +} + TEST_F(InstructionFusionTest, BitcastIntoAdd) { auto module = ParseHloString(R"( HloModule test_module @@ -365,7 +438,7 @@ static StatusOr FindHloInstruction( } return NotFound( "Computation '%s' does not contain an instruction with op code '%s'.", - computation.name().c_str(), HloOpcodeString(op).c_str()); + computation.name(), 
HloOpcodeString(op)); } TEST_F(InstructionFusionTest, MultiOutputFusion) { @@ -636,5 +709,44 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { } } +TEST_F(InstructionFusionTest, FuseIntoScatter) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY FuseIntoScatter { + p0 = s32[3,3] parameter(0) + operand = s32[3,3] add(p0, p0) + p1 = s32[2] parameter(1) + indices = s32[2] add(p1, p1) + p2 = s32[2,3] parameter(2) + updates = s32[2,3] add(p2, p2) + scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT add = s32[3,3] add(scatter, scatter) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Fusion(), op::Fusion())); + EXPECT_EQ(root->operand(0)->fusion_kind(), + HloInstruction::FusionKind::kInput); + EXPECT_THAT(root->operand(0)->fused_expression_root(), + op::Scatter(op::Add(), op::Add(), op::Add())); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index f544bcc91976233eff19d97037be989ea0855b86..ec3d8f9405840bb7be97ba5cd5725a4ac68a15a8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -128,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget = "__cudnn$convBackwardInput"; const char* const kCudnnConvBackwardFilterCallTarget = "__cudnn$convBackwardFilter"; +const char* const kCudnnConvBiasActivationForwardCallTarget = + "__cudnn$convBiasActivationForward"; bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() != HloOpcode::kCustomCall) { @@ -136,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { const auto& target = hlo.custom_call_target(); return target == kCudnnConvForwardCallTarget || target == kCudnnConvBackwardInputCallTarget || - target == kCudnnConvBackwardFilterCallTarget; + target == kCudnnConvBackwardFilterCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget; } bool ImplementedAsLibraryCall(const HloInstruction& hlo) { @@ -144,51 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -static HloInstruction* CreateCudnnConv( - const char* call_target, const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { - HloComputation* computation = lhs->parent(); - - // This call returns a tuple of (conv_result, scratch_memory), where - // conv_result is the actual result of the convolution, and scratch_memory is - // temporary memory used by cudnn. - // - // At the moment, we don't know how much scratch memory this conv is going to - // use, so we put u8[0] in this place. Later on another pass will choose - // which conv algorithm to use, and at that point we'll modify the shape of - // this second tuple element. 
- Shape call_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - - HloInstruction* custom_call = computation->AddInstruction( - HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); - custom_call->set_window(window); - custom_call->set_convolution_dimension_numbers(dnums); - return custom_call; -} - -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums) { - return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, - window, dnums); -} - -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums) { - return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, - reverse_filter, window, dnums); -} - -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, - output, window, dnums); -} - bool IsReductionToVector(const HloInstruction& reduce) { if (HloOpcode::kReduce != reduce.opcode()) { return false; @@ -216,7 +175,7 @@ bool IsReductionToVector(const HloInstruction& reduce) { // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls llvm::Value* EmitPrintf(absl::string_view fmt, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, llvm::IRBuilder<>* builder) { std::vector argument_types; for (auto argument : arguments) { @@ -279,5 +238,36 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } +StatusOr GetCudnnConvKind( + const HloCustomCallInstruction* instr) { + absl::string_view 
target = instr->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + return CudnnConvKind::kForward; + } + if (target == kCudnnConvBackwardInputCallTarget) { + return CudnnConvKind::kBackwardInput; + } + if (target == kCudnnConvBackwardFilterCallTarget) { + return CudnnConvKind::kBackwardFilter; + } + if (target == kCudnnConvBiasActivationForwardCallTarget) { + return CudnnConvKind::kForwardActivation; + } + return InternalError("Unexpected call target: %s", target); +} + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + case CudnnConvKind::kForwardActivation: + return "forward with activation"; + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index a35e250101c0743018b76fffb82e9db591c33de3..f373d4a8393a047aba599b0fae954e98a740161e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -28,6 +29,33 @@ limitations under the License. namespace xla { namespace gpu { +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. 
For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter + kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + + // (optionally) side_input) => output +}; + +StatusOr GetCudnnConvKind(const HloCustomCallInstruction* instr); + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. @@ -80,9 +108,9 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); // memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value // is not well-defined. // -// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// CudnnConvRewriter lowers kConvolution HLOs to these custom calls. // When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later -// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// on in the pipeline, CudnnConvAlgorithmChooser chooses an explicit // algorithm for each conv and sets the amount of scratch space needed. 
// // (Representing the scratch memory as an output may seem strange at first, but @@ -93,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); extern const char* const kCudnnConvForwardCallTarget; extern const char* const kCudnnConvBackwardInputCallTarget; extern const char* const kCudnnConvBackwardFilterCallTarget; +extern const char* const kCudnnConvBiasActivationForwardCallTarget; // Returns true if `hlo` will be implemented as a call to a cuDNN convolution // routine. @@ -102,23 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget; // kConvolution opcode. bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); -// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. -// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If -// you want just the conv result, you'll need to get-tuple-element the value -// returned by this function. -// -// The created cudnn call will use the default cudnn algorithm and no scratch -// space. -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums); -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums); -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums); - // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); @@ -127,7 +139,7 @@ bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. 
llvm::Value* EmitPrintf(absl::string_view fmt, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 7111b53944770c9dbfcd0611f67b18900bcf1ffb..a3821e077ecf6b1dce1e2c8785fe3a59516db2be 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -141,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { Status IrEmitter::EmitCallToNestedComputation( const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output) { + absl::Span operands, llvm::Value* output) { TF_RET_CHECK(nested_computation.num_parameters() > 0); llvm::Function*& emitted_function = computation_to_ir_function_[&nested_computation]; @@ -156,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - b_.CreateCall(emitted_function, arguments); + Call(emitted_function, arguments); return Status::OK(); } @@ -178,7 +178,22 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( computation.root_instruction()->shape().element_type(); bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; - llvm::Value* source = b_.CreateLoad(source_address, "source"); + llvm::Value* source = Load(source_address, "source"); + + // kCopy of RHS -> atomic store. 
+ if (root_opcode == HloOpcode::kCopy && + (element_type == F32 || is_atomic_integral) && + computation.root_instruction()->operand(0)->opcode() == + HloOpcode::kParameter && + computation.root_instruction()->operand(0)->parameter_number() == 1) { + llvm::StoreInst* store = Store(source, output_address); + store->setAtomic(llvm::AtomicOrdering::Unordered); + // Derive a minimum alignment from the type. The optimizer can increase it + // later. + store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type)); + return true; + } + if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -190,8 +205,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } if (is_atomic_integral) { // integral + integral - b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } @@ -202,8 +217,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -212,8 +227,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? 
llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -292,10 +307,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); + llvm::Value* cas_old_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. 
llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); @@ -309,29 +324,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = PtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = And(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = And(atomic_memory_address, mask); atomic_memory_address = - b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = b_.CreateAdd( - b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); + IntToPtr(atomic_memory_address, atomic_address_type); binop_output_address = - b_.CreateIntToPtr(binop_output_address, element_address_type); + Add(PtrToInt(cas_new_output_address, address_int_type), offset); + binop_output_address = IntToPtr(binop_output_address, element_address_type); } else { - atomic_memory_address = - b_.CreateBitCast(output_address, atomic_address_type); + atomic_memory_address = BitCast(output_address, atomic_address_type); binop_output_address = - b_.CreateBitCast(cas_new_output_address, element_address_type); + BitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. 
- llvm::Value* cas_old_output = - b_.CreateLoad(atomic_memory_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_old_output_address); + llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output"); + Store(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( b_.GetInsertPoint(), "atomic_op_loop_exit"); @@ -344,32 +356,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = Load(cas_old_output_address, "cas_old_output"); + Store(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); - llvm::Value* cas_new_output = - b_.CreateLoad(cas_new_output_address, "cas_new_output"); + llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = b_.CreateAtomicCmpXchg( - atomic_memory_address, cas_old_output, cas_new_output, - llvm::AtomicOrdering::SequentiallyConsistent, - llvm::AtomicOrdering::SequentiallyConsistent); + llvm::Value* ret_value = + AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. 
- b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. @@ -384,8 +393,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( // TODO(b/30258929): We only accept binary computations so far. return Unimplemented( "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); + "computation %s has %d.", + computation.name(), computation.num_parameters()); } if (MaybeEmitDirectAtomicOperation(computation, output_address, @@ -472,10 +481,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = b_.CreateInsertValue(result, value.first, {0}); - result = b_.CreateInsertValue(result, value.second, {1}); + result = InsertValue(result, value.first, {0}); + result = InsertValue(result, value.second, {1}); } else { - result = b_.CreateFMul(lhs_value, rhs_value); + result = FMul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -486,18 +495,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) && !ShapeUtil::IsScalar(rhs_shape)); - // Reduce along the last dimension of the LHS and the second-to-last dimension - // of the RHS. 
Vectors are a special case where the reduction dimension is 0 - // for both LHS and RHS. This results in a vector dot product producing a - // scalar. - const int64 lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, -1); - const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() - ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : dnums.lhs_batch_dimensions_size(); - - // Check that the batch dims don't cover the last two dims. + const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0); + + // Check that the batch dims don't cover the reduction dimensions. for (int64 batch_dim : dnums.lhs_batch_dimensions()) { CHECK_NE(lhs_reduction_dimension, batch_dim); CHECK_NE(rhs_reduction_dimension, batch_dim); @@ -505,7 +506,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + rhs_shape.dimensions(rhs_reduction_dimension)) + << "lhs_shape.dimensions(" << lhs_reduction_dimension + << ") = " << lhs_shape.dimensions(lhs_reduction_dimension) + << ", and rhs_shape.dimensions(" << rhs_reduction_dimension + << ") = " << rhs_shape.dimensions(rhs_reduction_dimension); // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. 
The reduction dimension of the LHS and RHS are handled @@ -559,21 +564,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); - llvm::Value* accum = b_.CreateLoad(accum_address); + llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_element, rhs_element, &b_); llvm::Value* accum_real = Real(accum, &b_); - llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); - updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* real_sum = FAdd(accum_real, value.first); + updated_accum = InsertValue(accum, real_sum, {0}); llvm::Value* accum_imag = Imag(accum, &b_); - llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); - updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); + llvm::Value* imag_sum = FAdd(accum_imag, value.second); + updated_accum = InsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); - updated_accum = b_.CreateFAdd(accum, product); + llvm::Value* product = FMul(lhs_element, rhs_element); + updated_accum = FAdd(accum, product); } - b_.CreateStore(updated_accum, accum_address); + Store(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. The index into the target address is the concatenation of the rhs @@ -595,7 +600,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - b_.CreateLoad(accum_address), // The value written to the target array. + Load(accum_address), // The value written to the target array. 
&b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -639,17 +644,16 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + Alloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -686,7 +690,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { *function, {accumulator_addr, input_address}, accumulator_addr)); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); }); } @@ -753,14 +757,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleIota(HloInstruction*) { - // TODO(b/64798317): implement iota on GPU. 
- return Unimplemented("Iota is not implemented on GPU."); -} - StatusOr IrEmitter::ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements) { + absl::Span parameter_elements) { llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( computation.root_instruction()->shape().element_type(), module_), @@ -769,11 +768,26 @@ StatusOr IrEmitter::ComputeNestedElement( for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); - b_.CreateStore(parameter_element, parameter_buffers.back()); + Store(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return b_.CreateLoad(return_buffer); + return Load(return_buffer); +} + +std::vector IrEmitter::ConstructIrArrayForOutputs( + const HloInstruction& hlo) { + std::vector output_arrays; + if (ShapeUtil::IsTuple(hlo.shape())) { + int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays.reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays.push_back(GetIrArray(hlo, hlo)); + } + return output_arrays; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 76e069fc41ab1275fc0fb20f86128785c287b6c0..880520148005838cc25a5be9e26c8bc9028a70ce 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -36,12 +37,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -64,7 +65,8 @@ namespace gpu { // IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is // not a subclass of gpu::IrEmitter, and in fact is better understood as an IR // generator generator. See comments on that class. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } + llvm::IRBuilder<>* builder() { return &b_; } + protected: // Constructs an IrEmitter with the given IrEmitter context. 
// ir_emitter_context is owned by the caller and should outlive the IrEmitter @@ -121,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } + + // Generates the IrArray for each output of an hlo instruction and returns + // a vector containing such IrArrays. + std::vector ConstructIrArrayForOutputs( + const HloInstruction& hlo); + // A convenient helper for calling BufferAssignment::GetUniqueSlice. BufferAllocation::Slice GetAllocationSlice( const HloInstruction& hlo, const ShapeIndex& index = {}) const { @@ -140,9 +149,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Emits a call in IR to the given nested computation with the given operands // and output. If no IR function has been previously emitted for the // computation, also emits such a function. - Status EmitCallToNestedComputation( - const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output); + Status EmitCallToNestedComputation(const HloComputation& nested_computation, + absl::Span operands, + llvm::Value* output); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` @@ -196,7 +205,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { StatusOr ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements); + absl::Span parameter_elements); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. 
`output_address` and `source_address` diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5c827e5f9cf3e1c04af444dae338a2ec411ce372..66c65f69758e5a2f4420935279835eaf086fea45 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { - const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); - std::vector target_arrays; - target_arrays.reserve(num_elems); - for (int64 i = 0; i != num_elems; ++i) { - target_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector target_arrays = + ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - - std::vector tuple_operand_ptrs; - tuple_operand_ptrs.reserve(num_elems); - for (const llvm_ir::IrArray& array : target_arrays) { - tuple_operand_ptrs.push_back(array.GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 84043689bdcd4c6af165c847a2d188753694cc61..eb8aaaea4f91f552c2f21f104b83924fd604ebfa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -33,6 +34,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -42,7 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -60,6 +62,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -80,7 +83,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -94,7 +96,6 @@ using absl::optional; using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; -using tensorflow::gtl::ArraySlice; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -176,7 +177,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args) { + absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( @@ -465,67 +466,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); - auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); - auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + std::vector operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config()); - const auto& target = 
custom_call->custom_call_target(); - std::unique_ptr thunk; - if (target == kCudnnConvForwardCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kForward, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); - } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); - } - - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique( + Cast(custom_call), std::move(operand_slices), + conv_result_slice, 
scratch_slice, tuple_result_slice)); return Status::OK(); } @@ -542,13 +494,68 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); - // HandleFusion specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires a initializer thunk that - // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kScatter: { + std::vector> thunks; + // The initialization from 'operand' is using different loop bounds, so + // emit it in a separate kernel. Treat it like a loop fusion, writing to + // the output buffer. + { + int unroll_factor = ComputeMaxUnrollFactor(fusion); + thunks.push_back(BuildKernelThunk( + fusion, /*implements_whole_instruction=*/false, unroll_factor)); + + std::vector operand_parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + operand_parameter_arrays.push_back(GetIrArray(*operand, *fusion)); + } + GpuElementalIrEmitter operand_elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter operand_fused_emitter(operand_parameter_arrays, + &operand_elemental_emitter); + TF_RETURN_IF_ERROR( + root->mutable_operand(0)->Accept(&operand_fused_emitter)); + + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( + *fusion, operand_fused_emitter.GetGenerator(root->operand(0)), + static_cast(thunks.back().get()))); + } + + // Now build the actual scatter, reading and writing to the freshly + // filled output buffer. + { + thunks.push_back( + BuildKernelThunk(fusion, + /*implements_whole_instruction=*/false)); + // Spin up a new fused emitter for the scatter kernel and emit it. 
+ std::vector scatter_parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + scatter_parameter_arrays.push_back(GetIrArray(*operand, *fusion)); + } + GpuElementalIrEmitter scatter_elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter scatter_fused_emitter(scatter_parameter_arrays, + &scatter_elemental_emitter); + TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter)); + TF_RETURN_IF_ERROR(EmitScatter( + thunks.back().get(), root, + /*scatter_indices_gen=*/ + scatter_fused_emitter.GetGenerator(root->operand(1)), + /*updates_gen=*/ + scatter_fused_emitter.GetGenerator(root->operand(2)))); + } + thunk_sequence_->emplace_back( + absl::make_unique(std::move(thunks), fusion)); + return Status::OK(); + } case HloOpcode::kTuple: case HloOpcode::kReduce: { + // HandleFusion specializes reduction from a multi-dimensional array to + // a 1D array. The specialized version requires a initializer thunk that + // initializes the output array to the initial value of the reduce. if (root->opcode() == HloOpcode::kReduce && ShapeUtil::IsTuple(root->shape())) { // TODO(b/112040122): Support variadic reduce. @@ -556,10 +563,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; - ArraySlice output_instructions = + absl::Span output_instructions = root->opcode() == HloOpcode::kTuple ? root->operands() - : ArraySlice(&root, 1); + : absl::Span(&root, 1); // For multi-output fusion emit an initializer for each tuple element. // Otherwise it's sufficient to just initialize the single output. 
@@ -718,8 +725,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* reduce, const IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { const HloInstruction* output = reduce->parent()->FusionInstruction(); @@ -729,19 +735,18 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - b_.CreateStore(extra_output_ir_value, extra_output_address); + Store(extra_output_ir_value, extra_output_address); } return Status::OK(); } Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; @@ -810,17 +815,17 @@ Status IrEmitterUnnested::EmitReductionToScalar( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { @@ -832,15 +837,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), + tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. 
@@ -849,11 +853,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( IrArray::Index input_index( /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -864,14 +868,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileSize), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileSize), + NSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. 
llvm::Value* tile_in_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); + Or(ICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); @@ -892,20 +896,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -920,10 +922,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. 
llvm::Value* lane_id = - b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); + URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { @@ -955,12 +956,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Divide the input matrix into tiles of size KxL. For example, when the // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like @@ -1043,12 +1043,12 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ + llvm::Twine(i * kTileWidth + x_offset)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1059,8 +1059,8 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); auto emit_tile_element_loop = [=](bool tile_in_y_bounds, bool tile_in_x_bounds) -> Status { @@ -1072,34 +1072,32 @@ Status IrEmitterUnnested::EmitColumnReduction( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* y = b_.CreateNSWAdd( - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); + llvm::Value* y = + NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), + tile_element_loop->GetIndVarValue()); // Unless we know that y is in bounds, we have to emit a check before // reading from the input. if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", - &b_); + ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. 
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); // Unless we know that x is in bounds, we have to emit a check before // reading from the input. if (!tile_in_x_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). This conversion is composed of a transposition from @@ -1126,7 +1124,7 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i * kTileWidth + x_offset], @@ -1141,20 +1139,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location // that's immediately beyond the tile. 
- llvm::Value* y_end = b_.CreateNSWAdd( - index_typed_constant(kTileHeight), - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + llvm::Value* y_end = + NSWAdd(index_typed_constant(kTileHeight), + NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location // that's immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileWidth), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileWidth), + NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); llvm::Value* tile_in_y_bounds = - b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); + Or(ICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); llvm::Value* tile_in_x_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); + Or(ICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. llvm_ir::LlvmIfData if_tile_in_y_bounds_data = @@ -1188,9 +1186,9 @@ Status IrEmitterUnnested::EmitColumnReduction( reduce->IsFused() ? 
reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( @@ -1246,12 +1244,11 @@ static std::pair ComputeTilingSchemeForReduction( Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // A naive algorithm is: // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. @@ -1379,11 +1376,11 @@ Status IrEmitterUnnested::EmitRowReduction( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1392,22 +1389,20 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = ZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); + URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = b_.CreateNSWAdd( + llvm::Value* last_x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - index_typed_constant(x_tile_size - 1), - b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(index_typed_constant(x_tile_size - 1), + NSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( &b_, @@ -1419,9 +1414,8 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = b_.CreateNSWAdd( - z_indvar, - b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); + llvm::Value* z = + NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", /*start=*/index_typed_constant(0), @@ -1429,22 +1423,20 @@ Status 
IrEmitterUnnested::EmitRowReduction( /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = b_.CreateNSWAdd( + llvm::Value* x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - x_indvar, b_.CreateNSWMul( - warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(x_indvar, + NSWMul(warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), - "x_in_bounds", &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, &b_); @@ -1452,7 +1444,7 @@ Status IrEmitterUnnested::EmitRowReduction( // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. 
We need to convert that to an index @@ -1483,7 +1475,7 @@ Status IrEmitterUnnested::EmitRowReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1503,8 +1495,8 @@ Status IrEmitterUnnested::EmitRowReduction( }; llvm::Value* tile_in_bounds = - b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - b_.CreateICmpULT(last_x, index_typed_constant(width))); + Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1532,20 +1524,18 @@ Status IrEmitterUnnested::EmitRowReduction( for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, 
shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1560,8 +1550,7 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = @@ -1607,13 +1596,12 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for @@ -1708,7 +1696,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } auto input = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); + absl::Span dimensions_to_reduce(reduce->dimensions()); HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. 
The specialized version requires an initializer thunk that @@ -1740,6 +1728,14 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { + // For the root node of the entry computation we can elide writing the tuple + // buffer. We can always figure out the contents of the tuples from buffer + // assignment because we insert copies to ensure non-ambiguous output buffers. + // GpuExecutable never reads the tuple buffer. + if (tuple == + tuple->parent()->parent()->entry_computation()->root_instruction()) { + return Status::OK(); + } bool all_tuple_elements_have_buffer = absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() @@ -1845,7 +1841,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. 
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, @@ -1866,15 +1862,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( + llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = b_.CreateNSWSub( - b_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_constant(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( + operand_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ICmpULT( operand_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -1884,7 +1880,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. 
@@ -1892,16 +1888,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently @@ -1917,11 +1913,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter( TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = b_.CreateLoad(select_return_buffer); + llvm::Value* result = Load(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. 
- llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), @@ -1930,7 +1926,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -1942,8 +1938,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) @@ -2026,6 +2022,178 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { return Status::OK(); } +Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + + std::vector> thunks; + + // Copy the operand into the output if it's not the same buffer already. 
+ auto operand_buffer = GetAllocationSlice(*operand); + auto destination_buffer = GetAllocationSlice(*scatter); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter)); + } + + thunks.push_back( + BuildKernelThunk(scatter, + /*implements_whole_instruction=*/thunks.empty())); + + TF_RETURN_IF_ERROR( + EmitScatter(thunks.back().get(), scatter, + /*scatter_indices_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*scatter_indices, *scatter) + .EmitReadArrayElement(index, &b_, "scatter_index"); + }, + /*updates_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*updates, *scatter) + .EmitReadArrayElement(index, &b_, "update"); + })); + + // Elide the sequential thunk if there's no copy. + if (thunks.size() == 1) { + thunk_sequence_->push_back(std::move(thunks[0])); + } else { + thunk_sequence_->emplace_back( + absl::make_unique(std::move(thunks), scatter)); + } + return Status::OK(); +} + +Status IrEmitterUnnested::EmitScatter( + Thunk* thunk, HloInstruction* scatter, + const llvm_ir::ElementGenerator& scatter_indices_gen, + const llvm_ir::ElementGenerator& updates_gen) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape())); + + auto loop_body_emitter = [&](const IrArray::Index& index) -> Status { + std::vector raw_window_multidim; + std::vector input_scatter_multidim; + std::vector raw_window_bounds; + + // Partition the index into window indices and scatter indices. + for (int64 i = 0, e = index.size(); i != e; ++i) { + // For window indices also remember the window size, this comes in handy + // later. 
+ if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + raw_window_multidim.push_back(index[i]); + raw_window_bounds.push_back(updates->shape().dimensions(i)); + } else { + input_scatter_multidim.push_back(index[i]); + } + } + DCHECK_EQ(raw_window_multidim.size(), + dim_numbers.update_window_dims_size()); + + // Apply inserted_window_dims to the window dimensions. + int64 raw_window_multidim_idx = 0; + std::vector input_window_multidim; + std::vector input_window_bounds; + for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_window_bounds.push_back(1); // Trivial dimension. + input_window_multidim.push_back(index.GetConstantWithIndexType(0)); + } else { + input_window_bounds.push_back( + raw_window_bounds[raw_window_multidim_idx]); + input_window_multidim.push_back( + raw_window_multidim[raw_window_multidim_idx]); + ++raw_window_multidim_idx; + } + } + DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + + // Insert a 1 dimension at the end if index_vector_dim requests one. + Shape scatter_indices_shape = scatter_indices->shape(); + if (dim_numbers.index_vector_dim() == + ShapeUtil::Rank(scatter_indices_shape)) { + scatter_indices_shape.add_dimensions(1); + scatter_indices_shape.mutable_layout()->add_minor_to_major( + dim_numbers.index_vector_dim()); + } + + // Now load the indices corresponding to the current window from + // scatter_indices. + llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim, + index.GetType()); + raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + llvm::Value* is_in_bounds = b_.getTrue(); + for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); + i != e; ++i) { + // Our index is stored along index_vector_dim, insert that into the lookup + // index into scatter_indices. 
+ raw_scatter_index_index[dim_numbers.index_vector_dim()] = + raw_scatter_index_index.GetConstantWithIndexType(i); + + int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); + TF_ASSIGN_OR_RETURN( + llvm::Value* const loaded_scatter_index, + scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( + scatter_indices_shape, scatter_indices->shape(), &b_))); + // And add the index to our window index. This yields the output index. + llvm::Value* casted_scatter_index = + IntCast(loaded_scatter_index, index.GetType(), + /*isSigned=*/true); + llvm::Value* dim_offset = + Add(input_window_multidim[operand_dim], casted_scatter_index); + input_window_multidim[operand_dim] = dim_offset; + + // Also do the bounds check now. + int64 max_index = operand->shape().dimensions(operand_dim) - + input_window_bounds[operand_dim] + 1; + // is_in_bounds = index >= 0 && index < dim_size-window_size+1 + // --> index u< dim_size-window_size+1 + is_in_bounds = + And(is_in_bounds, ICmpULT(casted_scatter_index, + index.GetConstantWithIndexType(max_index))); + } + + llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( + is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false); + llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); + // All done, now just read from the calculated input from the window, and do + // an atomic store to the calculated location in the output. + llvm_ir::IrArray::Index input_window_index(input_window_multidim, + index.GetType()); + HloInstruction* output_hlo = + scatter->IsFused() ? 
scatter->parent()->FusionInstruction() : scatter; + llvm::Value* output_address = + GetIrArray(*output_hlo, *output_hlo) + .EmitArrayElementAddress(input_window_index, &b_); + llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType( + updates->shape().element_type(), module_)); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); + Store(input_ir_value, input_address); + return EmitAtomicOperationForNestedComputation( + *scatter->to_apply(), output_address, input_address); + }; + + // Launch a kernel that reads every element in the updates tensor. We could + // also do one kernel per window instead if bounds checks turn out to be a + // bottleneck. + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + updates->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunk, + ir_emitter_context_->llvm_module()); + + return ParallelLoopEmitter(loop_body_emitter, updates->shape(), + launch_dimensions, &b_) + .EmitLoop(IrName(scatter), + GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(), + &b_)); +} + Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { thunk_sequence_->push_back( BuildKernelThunk(select, /*implements_whole_instruction=*/true)); @@ -2034,34 +2202,34 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; - auto keys = sort->operand(0); - auto values = sort->operand_count() > 1 ? 
sort->operand(1) : nullptr; - ShapeIndex keys_shape_index({}); - ShapeIndex values_shape_index({}); - if (values != nullptr) { - keys_shape_index = ShapeIndex({0}); - values_shape_index = ShapeIndex({1}); - } - auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); - auto values_destination = GetAllocationSlice(*sort, values_shape_index); - - if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(absl::make_unique( - /*source_address=*/GetAllocationSlice(*keys), - /*destination_buffer=*/keys_destination, - /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); - } - if (values != nullptr && values_destination != GetAllocationSlice(*values)) { - // TODO(b/26783907): Figure out why we never seem to share buffers for - // key/value sort. - thunks.push_back(absl::make_unique( - /*source_address=*/GetAllocationSlice(*values), - /*destination_buffer=*/values_destination, - /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); + Shape keys_shape = sort->operand(0)->shape(); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + // We assume that the layout of all involved operands and outputs is the + // same. + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, + sort->operand(i)->shape())); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + + // If possible, we share buffers. If that is not possible, we need to copy + // the values, because the emitter does the sorting in-place. + auto destination_buffer = GetAllocationSlice(*sort, shape_index); + auto source_address = GetAllocationSlice(*sort->operand(i)); + if (destination_buffer != source_address) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. 
+ thunks.push_back(absl::make_unique( + /*source_address=*/source_address, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()), + nullptr)); + } } int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); auto index_type = b_.getInt64Ty(); @@ -2085,7 +2253,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { thunks.push_back( BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - keys->shape(), ir_emitter_context_->device_description()); + keys_shape, ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); @@ -2096,12 +2264,21 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); } + IrArray keys_array; + std::vector values_arrays; + values_arrays.reserve(sort->operand_count() - 1); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + if (i == 0) { + keys_array = GetIrArray(*sort, *sort, shape_index); + } else { + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + } + } TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), - values != nullptr ? 
absl::make_optional( - GetIrArray(*sort, *sort, values_shape_index)) - : absl::nullopt, - IrName(sort), xor_mask, &b_, &launch_dimensions)); + dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_mask, + &b_, &launch_dimensions)); } } @@ -2367,8 +2544,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( *slice.allocation()))); CHECK_NE(loc, nullptr); } else { - loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -2376,8 +2553,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::Type* int8_double_pointer = llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = b_.CreateBitCast(loc, int8_double_pointer); - loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); + loc = BitCast(loc, int8_double_pointer); + loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2541,15 +2718,15 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index) { + HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value_operand = [&] { + HloInstruction* inst = fused ? 
hlo->fused_expression_root() : hlo; + HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: - return inst->operand(2); + return inst->mutable_operand(2); case HloOpcode::kReduce: - return inst->operand(1); + return inst->mutable_operand(1); case HloOpcode::kTuple: CHECK(hlo->IsMultiOutputFusion()) << ": " << hlo->ToString() << " is not a multi-output fusion."; @@ -2557,7 +2734,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( << ": Found '" << inst->operand(index.back())->opcode() << "' in " << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. - return inst->operand(index.back())->operand(1); + return inst->mutable_operand(index.back())->mutable_operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2584,7 +2761,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // Are all the bytes of this scalar equal to 0? If so, we can create a // MemzeroThunk. - ArraySlice literal_bytes( + absl::Span literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { return {absl::make_unique(GetAllocationSlice(*hlo, index), @@ -2629,28 +2806,35 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); - // If the init_value was fused into this reduce we have to generate it first. - if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { - CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - const Literal& literal = init_value_operand->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); + if (fused) { + // If init_value was fused into this reduce we have to generate it first. 
+ std::vector parameter_arrays; + for (HloInstruction* operand : hlo->operands()) { + parameter_arrays.push_back(GetIrArray(*operand, *hlo)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - global_for_const->setAlignment(kConstantBufferAlignBytes); - bindings_.BindHloToIrValue(*init_value_operand, global_for_const); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } else { + // In the unfused case the element is already there, just read from it. + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &b_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &b_); - }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) - .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) 
@@ -2674,8 +2858,7 @@ Status CheckHloBuffersShareAllocation( if (slice_a != slice_b) { return InternalError( "instruction %s %s does not share allocation with instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString()); } return Status::OK(); } @@ -2840,10 +3023,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( } // For multioutput fusion, we need to emit each operand and the root. - std::vector output_arrays; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector output_arrays = ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2851,12 +3031,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel( &hlo, launch_dimensions.launch_bound(), &b_))); - std::vector tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); } @@ -2868,34 +3045,19 @@ Status IrEmitterUnnested::EmitTargetElementLoop( static_cast(LastThunk())); } -int IrEmitterUnnested::ConstructIrArrayForOutputs( - const HloInstruction& hlo, std::vector* output_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_arrays->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_arrays->push_back(GetIrArray(hlo, hlo, {i})); - } - } else { - output_arrays->push_back(GetIrArray(hlo, hlo)); - } - return num_outputs; -} - -int 
IrEmitterUnnested::ConstructIrArrayForInputs( - const HloInstruction& hlo, std::vector* param_arrays) { - int64 num_params = hlo.operands().size(); - param_arrays->reserve(num_params); +std::vector IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo) { + std::vector param_arrays; + param_arrays.reserve(hlo.operands().size()); for (const HloInstruction* param : hlo.operands()) { - param_arrays->push_back(GetIrArray(*param, hlo)); + param_arrays.push_back(GetIrArray(*param, hlo)); } - return num_params; + return param_arrays; } int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays) { int64 num_outputs = 1; @@ -2922,7 +3084,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays) { int64 num_params = hlo.operands().size(); @@ -3063,18 +3225,18 @@ void EmitTiledElementalCodeWithBoundsCheck( // TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient // to launch fewer blocks so each transposes many tiles. LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids) { + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { // Parameters for the tiling algorithm. 
constexpr int64 kTileSize = 32; constexpr int64 kNumRows = 4; constexpr int64 kThreadsPerTile = kTileSize * kNumRows; // Construct IrArrays for the inputs and outputs. - std::vector output_arrays; - int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); - std::vector param_arrays; - int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); + int64 num_outputs = output_arrays.size(); + std::vector param_arrays = ConstructIrArrayForInputs(*hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector param_shmem_buffers(num_params, nullptr); @@ -3155,9 +3317,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( const IrArray::Index output_tile_origin = [&] { IrArray::Index index = output_tile_index; for (int i = 1; i < 3; ++i) { - index[i] = - b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); + index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); } return index; }(); @@ -3170,12 +3331,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( std::vector output_tile_bounds(3); for (int i = 1; i < 3; ++i) { // Only last row or column may not have full size. 
- output_tile_bounds[i] = b_.CreateSelect( - b_.CreateICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); + output_tile_bounds[i] = + Select(ICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); } KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); @@ -3193,7 +3354,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // Adds `addend` to the given `dim` of `index`. auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = b_.CreateAdd(index[dim], addend); + index[dim] = Add(index[dim], addend); return index; }; const IrArray::Index input_index = @@ -3209,10 +3370,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( llvm::Value* shmem_buffer = param_shmem_buffers[id]; // TODO(jlebar): Add AA metadata to this store. Tile buffers are // global variables, so LLVM can't infer much about it. - b_.CreateStore( - input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); } }); @@ -3233,9 +3393,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_index, "output", output_tile_bounds[2], output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc) { // TODO(jlebar): Add AA metadata to this load. 
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( - b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); output_in_reduced_shape_arrays[0].EmitWriteArrayElement( index, load_from_shmem_buffer, &b_); }); @@ -3263,7 +3423,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_in_reduced_shape_arrays.size()); for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, b_.CreateExtractValue(output_value, i), &b_); + index, ExtractValue(output_value, i), &b_); } } else { output_in_reduced_shape_arrays[0].EmitWriteArrayElement( @@ -3274,12 +3434,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // For multioutput fusion, emit a tuple with all the individual outputs. if (hlo->IsMultiOutputFusion()) { - std::vector tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, - module_); + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); } return launch_dimensions; @@ -3312,7 +3467,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { if (!reduced_dims_021.has_value()) { reduced_dims_021 = curr_reduced_dims_021; } - if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) { + if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) { // There is more than one possible transpose. Instead of picking one // transpose, we simply give up here. return false; @@ -3345,7 +3500,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use - // shared memory in fusions. 
If in the future other fusile ops use shared + // shared memory in fusions. If in the future other fusible ops use shared // memory, we'll have to adjust this heuristic. constexpr int kMinBlocksPerCore = 3; constexpr int64 kShmemPerCore = 48 * 1024; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 525441990795e160ba0e8facb910d5cc9796c4bb..93f11c069a4cebdf3c79cba17c824eded4f4b1db 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; @@ -105,13 +106,12 @@ class IrEmitterUnnested : public IrEmitter { // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args); + absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. 
Status EmitExtraOutputsForReduce( const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens); // EmitColumnReduction and EmitRowReduction emit code for column and row @@ -127,12 +127,11 @@ class IrEmitterUnnested : public IrEmitter { Status EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a @@ -143,23 +142,21 @@ class IrEmitterUnnested : public IrEmitter { Status EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. 
Status EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which @@ -180,33 +177,37 @@ class IrEmitterUnnested : public IrEmitter { // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in + // the process. `scatter` may be fused, scatter indices are taken from + // `scatter_indices_gen`, updates from `updates_gen`. The output buffer is + // expected to have the operand values in it already. + Status EmitScatter(Thunk* thunk, HloInstruction* scatter, + const llvm_ir::ElementGenerator& scatter_indices_gen, + const llvm_ir::ElementGenerator& updates_gen); + // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. bool CheckAndEmitHloWithTile021(HloInstruction* hlo); // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // returns the launch dimensions for the kernel. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. 
- LaunchDimensions EmitHlo021Tile( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids); - // Generates the IrArray for each output of hlo and returns the number of - // outputs. - int ConstructIrArrayForOutputs(const HloInstruction& hlo, - std::vector* output_arrays); - // Generates the IrArray for each input of hlo and returns the number of - // inputs. - int ConstructIrArrayForInputs(const HloInstruction& hlo, - std::vector* param_arrays); + LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, + absl::Span reduced_output_dims, + absl::Span tiled_param_ids); + + // Generates the IrArray for each input of an hlo and returns a vector that + // contains such IrArrays. + std::vector ConstructIrArrayForInputs( + const HloInstruction& hlo); + // For each output of the `hlo` instruction, constructs the reduced shape for // the output with the given `reduced_output_dims` and cast the original // output IrArray element in `output_arrays` to the reduced shape. Returns @@ -214,7 +215,7 @@ class IrEmitterUnnested : public IrEmitter { int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in @@ -226,7 +227,7 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays); @@ -250,7 +251,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. 
StatusOr> BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index = {}); + HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index d856299889fa7598acc78f3b8a5f5d613c58271d..e09b8fbd3ba275e14accbf88c21f3d10f34198d9 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -27,10 +27,10 @@ limitations under the License. namespace xla { namespace gpu { -KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction, - int unroll_factor) +KernelThunk::KernelThunk(absl::Span args, + const string& kernel_name, + const HloInstruction* hlo_instruction, + int unroll_factor) : Thunk(Kind::kKernel, hlo_instruction), args_(args.begin(), args.end()), kernel_name_(kernel_name), @@ -41,11 +41,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - absl::string_view ptx = executable.ptx(); - // Convert absl::string_view to se::port::StringPiece because - // StreamExecutor uses the latter. 
- loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + loader_spec_->AddCudaPtxInMemory(executable.ptx(), kernel_name_); if (!executable.cubin().empty()) { loader_spec_->AddCudaCubinInMemory( @@ -63,7 +59,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, if (kernel_cache_.end() == it) { it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + return InternalError("Unable to load kernel %s", kernel_name_); } } @@ -107,7 +103,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { - return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); + return InternalError("Unable to launch kernel %s", kernel_name_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index d751de50ad6671b3bf88cd4de49a8feb448e13ba..f63db5c3696f8f3bbd5956724240b2b06b4f1b98 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -47,7 +47,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice args, + KernelThunk(absl::Span args, const string& kernel_name, const HloInstruction* hlo_instruction, int unroll_factor); KernelThunk(const KernelThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index ccf082c4c65f91bf92e5d8a934c09150ad27ef50..698d2d51cc81a6c87f6578f1f35cdb47cf6bb4f2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -36,6 +36,7 @@ cc_library( "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc index a3c74507ddc2ffdbcea6ea4ef97b6f7b0cf250a5..85bc58cb445627695a46171db64cd8a1f10e0fc8 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Support/FileSystem.h" @@ -22,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -87,9 +87,10 @@ void IrDumpingPassManager::run(llvm::Module &module) { llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); const string basename = ReplaceFilenameExtension( absl::string_view(tensorflow::io::Basename(input_filename_)), - tensorflow::strings::Printf( + absl::StrFormat( "pass-%02d.before.%s.ll", i, - (PI == nullptr ? "unknown" : PI->getPassArgument().data()))); + absl::string_view(PI == nullptr ? "unknown" + : PI->getPassArgument().data()))); llvm::legacy::PassManager::add( new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename))); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index e18d7e764a880195ab183f754fc17d07c7f17a2f..8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -57,7 +57,6 @@ limitations under the License. 
#include "llvm/Transforms/Scalar.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 9fb6f569ae5f950b7dd9befb1ad4865ab941bd48..835924024b7b7de79624a369a69b07d72ac751ab 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -24,13 +24,14 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -86,67 +87,13 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, get_element_shape(element_instr_1), get_element_shape(element_instr_2)); } -namespace { -bool IsInputFusibleReduction(HloInstruction* instr) { - if (instr->IsMultiOutputFusion()) { - for (const HloInstruction* operand : - instr->fused_expression_root()->operands()) { - if (operand->opcode() == HloOpcode::kReduce) { - CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) - << " Reduce multi-output fusion " << instr->ToString() - << " must be an input fusion."; - return true; - } - } - return false; - } else if (instr->opcode() == HloOpcode::kFusion) { - // 
The loop emitter can handle to-vector reduce fusions. Such reduce - // fusions have the fusion kind kLoop rather than kInput. We do not fuse - // to-vector reduce fusions, because the resulting fusions may no longer be - // supported by loop emitter. - return IsReductionToVector(*instr->fused_expression_root()); - } else { - return IsReductionToVector(*instr); - } -} - -// The code emitted for reduction suffers from poor data locality if the layouts -// of input parameters differ. In such situtations it is beneficial not to fuse. -// We consider input params with maximum rank only. Params with smaller ranks -// will be broadcasted and have not been observed to cause data locality issues. -// TODO(b/111977086): Improve reduce emitters to remove this limitation. -bool ReduceFriendlyInputLayouts(HloInstruction* instr) { - std::vector params; - if (instr->opcode() == HloOpcode::kFusion) { - params = instr->fused_parameters(); - } else { - for (HloInstruction* operand : instr->operands()) { - params.push_back(operand); - } - } - int64 max_rank = 0; - const Layout* max_rank_layout; - for (HloInstruction* param : params) { - if (ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); - max_rank_layout = ¶m->shape().layout(); - } - } - return absl::c_all_of(params, [&](HloInstruction* param) { - return (ShapeUtil::Rank(param->shape()) < max_rank) || - (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); - }); -} - -} // namespace - bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { // We can fuse reduces and loop fusions. Elementwise instructions can be fused // with any other instruction. // TODO(b/112957171): This should use the same isFusible logic as // instruction_fusion. 
- return instr->IsFusable() && - (IsInputFusibleReduction(instr) || + return instr->IsFusible() && + (IsInputFusibleReduction(*instr) || (instr->opcode() == HloOpcode::kFusion && instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || instr->IsElementwise()); @@ -154,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, HloInstruction* instr2) { - tensorflow::gtl::FlatSet in_list; + absl::flat_hash_set in_list; for (auto instr : instr1->operands()) { if (!IsProfitableOperand(instr)) { continue; @@ -201,10 +148,10 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { bool changed = false; RecomputeReachability(); - tensorflow::gtl::FlatSet to_fuse; + absl::flat_hash_set to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, - // then filter out instructions that will be no longer fusable because of + // then filter out instructions that will be no longer fusible because of // reachability change. This avoids recalculating reachability on a large set // of instructions. 
std::vector> @@ -219,8 +166,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << consumer->name() << " has no users."; continue; } - if (!IsInputFusibleReduction(consumer)) { - VLOG(3) << consumer->name() << " is not an input-fusable reduction."; + if (!IsInputFusibleReduction(*consumer)) { + VLOG(3) << consumer->name() << " is not an input-fusible reduction."; continue; } VLOG(3) << consumer->name() @@ -229,8 +176,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { auto consumer_operands = consumer->operands(); for (size_t i = 0; i < consumer_operands.size(); ++i) { HloInstruction* producer = consumer_operands[i]; - if (!producer->IsFusable()) { - VLOG(3) << producer->name() << " is not fusable."; + if (!producer->IsFusible()) { + VLOG(3) << producer->name() << " is not fusible."; continue; } const bool is_loop_fusion = @@ -244,7 +191,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " has an incompatible shape."; continue; } - if (!ReduceFriendlyInputLayouts(producer)) { + if (!LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { VLOG(3) << producer->name() << " has inputs with mixed layouts."; continue; } @@ -270,7 +217,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } } - // Filter out pairs that will be no longer fusable because of reachability + // Filter out pairs that will be no longer fusible because of reachability // change. 
for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 67ca5d49eee8508e93284b134f8410eb3a89f9ce..f0b4d67ab8463a39161f71908746cad9e2a8670a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -22,7 +22,7 @@ namespace xla { namespace gpu { // Multi-output fusion of sibling and producer-consumer instructions for the -// Jellyfish backend. +// GPU backend. class GpuMultiOutputFusion : public MultiOutputFusion { public: GpuMultiOutputFusion(); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index c822c94f1b102e02be4a13a35892a2c181702383..8a6e5327e082791ff857a89e840c6a4f045f0edb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -259,7 +259,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { // Fusing a reduce into a loop fusion would require changing the fusion kind. // That's not supported yet. 
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -277,7 +277,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -301,7 +301,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) @@ -324,7 +324,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) @@ -358,7 +358,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, 
p0.1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 695feadb11ce9a3baf0c6732a9f6df61a4fcd308..791d414c915e6f23d84a38ae99dcfa9a59ab6353 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -36,26 +36,26 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include 
"tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" -#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" @@ -75,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" -#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -176,8 +175,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); - pipeline.AddPass(); - pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -204,20 +201,23 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (PadInsertion). + // (CudnnConvPaddingLegalization). HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. 
- pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { - pipeline.AddPass(); - // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element - // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + // CudnnConvPadForTensorCores leaves behind unnecessary + // tuple/get-tuple-element pairs that TupleSimplifier fixes. pipeline.AddPass(); } + // CudnnConvRewriter, CudnnConvPaddingLegalization and + // CudnnConvPadForTensorCores may add instructions which can be simplified + // by constant folding. + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -230,14 +230,17 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // a layout-sensitive verifier! HloPassPipeline pipeline("layout assignment"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, stream_exec); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { HloPassPipeline pipeline("post-layout_assignment"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -250,7 +253,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Choose the fastest algorithm for each conv. // // We pick the algorithm before fusion so we can generate better HLO. 
After - // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // CudnnConvRewriter, our convolutions are CustomCalls which return a // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of // scratch: // @@ -268,12 +271,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // However, if we were to run CudnnConvAlgorithmPicker after fusion // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass( - stream_exec, device_allocator, compiler); + pipeline.AddPass(stream_exec, device_allocator, + compiler); // Clean up new_tuple described above. pipeline.AddPass(); @@ -283,8 +286,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + fusion.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); @@ -296,7 +301,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline reduce_pipeline("reduce-precision"); reduce_pipeline.AddInvariantChecker( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ 
-322,8 +328,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -398,11 +406,11 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { "prefers >= 9.2.88). Compilation of XLA kernels below will likely " "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " "binary is sufficient."; - } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot - << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to " "miscompile XLA code, leading to incorrect results or " "invalid-address errors.\n\nYou do not need to update to CUDA " "9.2.88; cherry-picking the ptxas binary is sufficient."; @@ -565,8 +573,8 @@ StatusOr> NVPTXCompiler::RunBackend( // must also be used to determine the thunk launch schedule. std::unique_ptr stream_assignment = AssignStreams(*module); TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - HloSchedule::Build(*module, *stream_assignment, pointer_size_)); + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
@@ -817,9 +825,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, } StatusOr>> -NVPTXCompiler::CompileAheadOfTime( - std::vector> module, - const AotCompilationOptions& options) { +NVPTXCompiler::CompileAheadOfTime(std::unique_ptr module_group, + const AotCompilationOptions& options) { return Unimplemented( "not yet implemented: NVPTXCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 08ef6ef56c5e2637447255c5c7eb5b309cada80e..f79ae2990ae7d6e6985b15727a72358289121aa9 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,13 +20,14 @@ limitations under the License. #include #include +#include "absl/container/node_hash_map.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -58,7 +59,7 @@ class NVPTXCompiler : public LLVMCompiler { DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> module, + CompileAheadOfTime(std::unique_ptr module_group, AotCompilationOptions const& options) override; se::Platform::Id PlatformId() const override; @@ -140,10 +141,10 @@ class NVPTXCompiler : public LLVMCompiler { tensorflow::condition_variable compilation_done_cv_; }; - // Don't even think about switching this to FlatMap; iterator stability is - // critical here. 
- std::unordered_map + // Don't even think about switching this to flat_hash_map; iterator stability + // is critical here. + absl::node_hash_map compilation_cache_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index b99d998c4d7df514c024b1f8d643d08c72059d0e..e0f3e84a4cb25792cf10d38fc529f3e638acf8e4 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -96,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Outfeeding from GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index ca57cacb983bd2492a36dc462c09b357abb7ec37..8154d75d23a6d49153ccb6824402aff73f365617 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) : LoopEmitter(target_element_generator, target_arrays, b), diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index cc7da2e73b681bb351e722cc3fb39f7746f45568..f32ea1ce4c4192f39851a6441c46663df3063724 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ 
b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -47,11 +47,10 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // // This is used in multi-output fusion. target_element_generator should // produce a struct with N elements, one for each of target_arrays. - ParallelLoopEmitter( - const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, - int unroll_factor = 1); + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + absl::Span target_arrays, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* b, int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index c927c5ee1666b6198d96750ff372ac83813a9df9..375f68a15957936151aee068582a714b62694af2 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -34,9 +34,8 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << tensorflow::strings::Printf("[block: %lld, thread: %lld]", - launch_dims.block_count(), - launch_dims.threads_per_block()); + out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), + launch_dims.threads_per_block()); return out; } @@ -63,13 +62,8 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - auto threads_per_core = device_desc.threads_per_core_limit(); - auto blocks_per_core = device_desc.blocks_per_core_limit(); - int64 threads_per_block; - if (threads_per_core != 0 && blocks_per_core != 0) { - threads_per_block = device_desc.threads_per_core_limit() / - device_desc.blocks_per_core_limit(); - } else { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { static std::atomic log_count{0}; if (log_count.fetch_add(1) < 8) { LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " @@ -91,9 +85,9 @@ LaunchDimensions CalculateLaunchDimensions( } int64 block_count = CeilOfRatio(num_elements, threads_per_block); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " - "block) = ceil(%lld/%lld) = %lld", + "block) = ceil(%d/%d) = %d", num_elements, threads_per_block, block_count); return LaunchDimensions(block_count, threads_per_block); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h index c2df83aaa4347a9439798acc6cfc2ba0db995232..52d38b6f20e8d61e2d4966ad15a5583a9cd2e945 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h +++ 
b/tensorflow/compiler/xla/service/gpu/stream_assignment.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace gpu { @@ -34,7 +34,7 @@ class StreamAssignment { private: int stream_count_ = 1; // At least the main stream. - tensorflow::gtl::FlatMap hlo_to_stream_number_; + absl::flat_hash_map hlo_to_stream_number_; }; // Assigns GPU streams to instructions in `module`. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 3f75d8b55959495017f1b08d61bd6e7b44bed27f..c4f43cc9a614283acb376b5f98e4976615b590ad 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -16,18 +16,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloTestBase { +class StreamAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr CreateNewModule() { HloModuleConfig config; @@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - 
HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -98,26 +99,26 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], 
d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 05b305ea4cdfdbaeb42544b626a6b9990bb42f57..08ff52211af163fec39646ca6bf14da9d1b815e4 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { @@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, input_layout.push_back(dnums.input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid input layout: ", - DataLayoutString(input)); + return InternalError("Invalid input layout %s for conv with dnums %s", + DataLayoutString(input), + ConvolutionDimensionNumbersToString(dnums)); } std::vector filter_layout; @@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, filter_layout.push_back(dnums.kernel_input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid filter layout: ", - FilterLayoutString(filter)); + return InternalError("Invalid filter layout %s for conv with dnums 
%s", + FilterLayoutString(filter), + ConvolutionDimensionNumbersToString(dnums)); } std::vector output_layout; @@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, output_layout.push_back(dnums.output_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid output layout: ", - DataLayoutString(output)); + return InternalError("Invalid output layout %s for conv with dnums %s", + DataLayoutString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), @@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(input, nhwc_input)) { input_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid input layout: ", - input.ShortDebugString()); + return InternalError("Invalid input layout %s for conv with dnums %s", + LayoutUtil::HumanString(input), + ConvolutionDimensionNumbersToString(dnums)); } FilterLayout filter_layout; @@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(filter, nhwc_filter)) { filter_layout = FilterLayout::kOutputYXInput; } else { - return tensorflow::errors::Internal("Invalid filter layout: ", - filter.ShortDebugString()); + return InternalError("Invalid filter layout %s for conv with dnums %s", + LayoutUtil::HumanString(filter), + ConvolutionDimensionNumbersToString(dnums)); } DataLayout output_layout; @@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(output, nhwc_output)) { output_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid output layout: ", - output.ShortDebugString()); + return InternalError("Invalid output layout %s for conv with dnums %s", + LayoutUtil::HumanString(output), + 
ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(input_layout, filter_layout, output_layout); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index db4a33dc564b62b5fe54b725ea453a6fcbfb3287..d22ffc1754dfd43f9e5e0677553f26610f4b8112 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -25,15 +25,17 @@ filegroup( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) cc_library( name = "gpu_codegen_test", testonly = True, srcs = ["gpu_codegen_test.cc"], hdrs = ["gpu_codegen_test.h"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -48,9 +50,7 @@ cc_library( tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -67,9 +67,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/core:test_main", @@ -79,9 +77,7 @@ tf_cc_test( tf_cc_test( name = "gpu_index_test", srcs = ["gpu_index_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -102,9 +98,7 @@ tf_cc_test( tf_cc_test( name = "gpu_infeed_test", srcs = ["infeed_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -125,9 +119,7 @@ tf_cc_test( tf_cc_test( name = "gpu_kernel_tiling_test", srcs = ["gpu_kernel_tiling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + 
tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo", @@ -142,7 +134,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ldg_test", srcs = ["gpu_ldg_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -159,9 +151,7 @@ tf_cc_test( tf_cc_test( name = "gpu_noalias_test", srcs = ["gpu_noalias_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -178,9 +168,7 @@ tf_cc_test( tf_cc_test( name = "gpu_fusion_test", srcs = ["gpu_fusion_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -194,9 +182,7 @@ tf_cc_test( tf_cc_test( name = "gpu_unrolling_test", srcs = ["gpu_unrolling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -211,9 +197,7 @@ tf_cc_test( name = "gpu_alignment_test", testonly = True, srcs = ["gpu_alignment_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -225,3 +209,29 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cudnn_fused_conv_rewriter_test", + srcs = ["cudnn_fused_conv_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "gpu_atomic_test", + srcs = ["gpu_atomic_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + 
"//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_conv_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8bdb4c8080aabe8cc324291ad9fc28b01d4eaf35 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_conv_rewriter_test.cc @@ -0,0 +1,282 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class CudnnFusedConvRewriterTest : public HloTestBase { + protected: + string GetOptimizedHlo(absl::string_view hlo_string) { + return backend() + .compiler() + ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + } + + void TestMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_EQ(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convForward")) + << optimized_hlo_string; + EXPECT_NE(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convBiasActivationForward")) + << optimized_hlo_string; + EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) + << optimized_hlo_string; + } + } + + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); + EXPECT_NE(absl::string_view::npos, + optimized_hlo.find("__cudnn$convForward")) + << optimized_hlo; + EXPECT_EQ(absl::string_view::npos, + optimized_hlo.find("__cudnn$convBiasActivationForward")) + << optimized_hlo; + } + } +}; + +TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { + // max(0, conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) 
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestBias) { + // max(0, conv(x, w) + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) { + // max(0, conv(x, w) + side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) { + // max(0, conv(x, w) + side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] 
parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) { + // max(0, 0.999994934 * conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) { + // max(0, conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, 
TestScaledConvAndScaledSideInput) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] 
multiply(side_input, alpha_side_input) + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) { + // max(0.1, conv(x, w)) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + point_one = TYPE[] constant(0.1) + point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestMatchBroadcastedBiasOnly) { + // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input1 = TYPE[1,3,3,64] parameter(2) + side_input2 = TYPE[1,3,3,64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input2) + add2 = TYPE[1,3,3,64] add(add1, side_input1) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b18c4c63714b4b3c06d7fa85f4a7a75b8e9ae12 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 
The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAtomicTest : public GpuCodegenTest {}; + +TEST_F(GpuAtomicTest, TestStore) { + const char* hlo_string = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: store atomic{{.*}}unordered, align 4 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 0e84ec7e621fcd1778725dc2743d7a70fb01c47a..79e77d4c4d649020cf52ac25c220c3f90e8469b9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ 
b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -39,8 +39,7 @@ void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); - string ptx_str = - std::string(static_cast(executable.get())->ptx()); + string ptx_str(static_cast(executable.get())->ptx()); StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 4550f36fdfc097632fed4956fcd3e42ef8a919c5..780539c164277f14c2bd964024f7c3ca179f4ada 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {}; TEST_F(GpuCopyTest, UseMemcpy) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index 9072b30317d253fd6d50e9d98949cad4eaebfe7b..f8120a5fa00ce38644cd85c54d5ef65701be1eda 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - 
TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } @@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) { TEST_F(InfeedTest, LargeInfeed) { Array4D array(80, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); + TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D(array)); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + 
{LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests that a large tuple infeed can be handled. TEST_F(InfeedTest, SingleInfeedLargeTuple) { Array4D array(40, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR4FromArray4D(array).get(), - LiteralUtil::CreateR0(5).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR4FromArray4D(array), + LiteralUtil::CreateR0(5)})); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 2d5735d6c40ccd26f0e527f1a02403910db4c812..dcdbf2cf3c2aa87cc11a3473a765cb405b50e2a6 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -34,8 +34,7 @@ namespace gpu { // issue (b/31336476). 
class TupleThunk : public Thunk { public: - TupleThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, + TupleThunk(absl::Span tuple_element_buffers, const BufferAllocation::Slice& dest_buffer, const HloInstruction* hlo_instruction) : Thunk(Kind::kTuple, hlo_instruction), diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 828fc2884bd7d58333d86c35a537f06467cf6e4a..c4754fe378960834e1157b0ff25c03c0fc4754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", stream, - block_status.error_message().c_str()); + block_status.error_message()); } if (!condition_result) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 40183de96ee363996e6b0b883a78e7a8b5d13ab2..9a61f8ac5a62e38e687a93890eb33481a01d51c8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -26,9 +26,6 @@ limitations under the License. 
namespace xla { namespace { -using ::testing::Eq; -using ::testing::HasSubstr; - class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index a2be89511babc23ebcd5cb40abee2a95d16dc451..ef70b688778df5115e2b5fe572d253a6948d076f 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -112,8 +112,11 @@ std::unique_ptr MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 38c3982ebf170d5733d56a05106835d1eaa4f2e1..9220865867b770eebfb1ada8f31a5d24693a4b8d 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,24 +18,26 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +using absl::flat_hash_map; +using absl::flat_hash_set; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -47,17 +49,16 @@ StatusOr HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique(), *module, - module_sequence, *points_to_analysis, size_function)); + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, @@ -71,13 +72,13 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator 
heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -86,13 +87,13 @@ StatusOr HeapSimulator::Run( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -102,7 +103,7 @@ StatusOr HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -116,8 +117,10 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. 
- FlatMap> live_buffers; - FlatMap> used_buffers; + flat_hash_map> + live_buffers; + flat_hash_map> + used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, const BufferValue* buffer) { @@ -133,7 +136,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,7 +170,8 @@ Status HeapSimulator::RunComputation( std::vector dead_buffers_to_free; std::vector operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -212,7 +217,7 @@ Status HeapSimulator::RunComputation( VLOG(4) << " Removing user " << instruction->name() << " from buffer " << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); - FlatSet* live_set = &it->second; + flat_hash_set* live_set = &it->second; live_set->erase(instruction); if (live_set->empty()) { live_buffers.erase(it); @@ -234,7 +239,8 @@ Status HeapSimulator::RunComputation( // that we should assign. // Make sure each buffer get reused at most once. 
- FlatSet reused_buffers; + flat_hash_set reused_buffers; + int64 alloc_size_by_instruction = 0; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -267,14 +273,15 @@ Status HeapSimulator::RunComputation( if (!shared) { VLOG(3) << " Allocating: " << buffer->ToString(); + alloc_size_by_instruction += size_fn_(*buffer); Alloc(buffer, instruction); } } // Account for the memory used by subcomputations when estimating the // current heap size. if (memory_by_computation_ != nullptr) { - algorithm_->AccountForSubcomputationMemory(instruction, - *memory_by_computation_); + algorithm_->AccountForSubcomputationMemory( + instruction, alloc_size_by_instruction, *memory_by_computation_); } // If all computations in the module have been scheduled, we can save memory @@ -285,14 +292,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. 
- if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -322,7 +329,7 @@ Status HeapSimulator::RunComputation( to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const BufferValue* buffer = buffer_pending.first; - const FlatSet& pending = buffer_pending.second; + const flat_hash_set& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; to_free.push_back(buffer); @@ -343,16 +350,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, - const tensorflow::gtl::FlatMap* + const HloSchedule* schedule, + const absl::flat_hash_map* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} @@ -380,10 +387,8 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - const HloInstruction* instruction_to_calc_aliasing = - memory_by_computation_ == 
nullptr ? nullptr : instruction; - algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); - no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); + algorithm_->Alloc(buffer, size); + no_fragmentation_stats_->Alloc(buffer, size); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -521,21 +526,9 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } -void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - // The output buffer of while/call/conditional is always aliased with the - // output buffer of the root instruction in the body. Don't double count. - if (instruction == nullptr || - (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kCall && - instruction->opcode() != HloOpcode::kConditional)) { - Alloc(buffer, size); - } -} - void NoFragmentationStatsHeap::AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. @@ -549,6 +542,14 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory( } } } + if (max_subcomputation_bytes > 0 && + (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. 
+ max_subcomputation_bytes -= alloc_size_by_instruction; + } max_heap_size_ = std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); } @@ -735,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() { return result_; } +void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + result_.chunk_map.emplace(buffer, Chunk{0, 0}); + return; + } + auto emplace_result = buffer_intervals_.emplace( + buffer, BufferInterval{buffer, size, current_time_, -1}); + DCHECK(emplace_result.second); + ++current_time_; +} + +void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + return; + } + BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer); + DCHECK_EQ(buffer_interval.buffer, buffer); + DCHECK_EQ(buffer_interval.size, size); + DCHECK_EQ(buffer_interval.end, -1); + buffer_interval.end = current_time_; + ++current_time_; +} + +namespace { + +// Node in BufferIntervalTree that stores the alloc and free times of a buffer, +// and the chunk assigned to it. +struct BufferIntervalTreeNode { + // Alloc time. + int64 start; + // Free time. + int64 end; + // Maximum free time of all nodes in the subtree where this node is the root. + int64 subtree_end; + // Allocated chunk for the buffer. + HeapSimulator::Chunk chunk; + // Left child. + BufferIntervalTreeNode* left; + // Right child. + BufferIntervalTreeNode* right; +}; + +// An interval tree that can query buffers overlapping in time. +class BufferIntervalTree { + public: + explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {} + + using Chunk = HeapSimulator::Chunk; + + // Adds a buffer to the interval tree, with the time interval and allocated + // chunk specified. 
+ void Add(int64 start, int64 end, const Chunk& chunk) { + int index = node_count_; + DCHECK_LT(index, node_storage_.size()); + ++node_count_; + + node_storage_[index] = + BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr}; + + if (index == 0) { + // This is root. + return; + } + + BufferIntervalTreeNode* parent = &node_storage_[0]; + while (true) { + parent->subtree_end = std::max(parent->subtree_end, end); + if (parent->start > start) { + if (parent->left == nullptr) { + parent->left = &node_storage_[index]; + return; + } + parent = parent->left; + } else { + if (parent->right == nullptr) { + parent->right = &node_storage_[index]; + return; + } + parent = parent->right; + } + } + } + + // Returns vector of allocated chunks that overlap with the given time + // interval. + std::vector ChunksOverlappingInTime(int64 start, int64 end) { + std::vector result; + if (node_count_ == 0) { + return result; + } + std::vector visiting_stack; + visiting_stack.push_back(&node_storage_[0]); + while (!visiting_stack.empty()) { + BufferIntervalTreeNode* top = visiting_stack.back(); + visiting_stack.pop_back(); + if (start > top->subtree_end) { + continue; + } + if (top->left != nullptr) { + visiting_stack.push_back(top->left); + } + if (top->start <= end && top->end >= start) { + result.push_back(top->chunk); + } + if (end < top->start) { + continue; + } + if (top->right != nullptr) { + visiting_stack.push_back(top->right); + } + } + return result; + } + + private: + int64 node_count_ = 0; + std::vector node_storage_; +}; + +} // namespace + +HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { + std::vector sorted_buffer_intervals; + for (auto& entry : buffer_intervals_) { + sorted_buffer_intervals.push_back(entry.second); + } + std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - 
y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); + + BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); + for (auto& buffer_interval : sorted_buffer_intervals) { + auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( + buffer_interval.start, buffer_interval.end); + std::sort( + chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); + + // Find the minimum free chunk that can hold this buffer. + Chunk min_fit_chunk{-1, INT64_MAX}; + auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) { + if (free_size < buffer_interval.size) { + return; + } + + if (free_size < min_fit_chunk.size) { + min_fit_chunk = {free_offset, free_size}; + } + }; + + int64 offset = 0; + for (auto& chunk : chunks_overlapping_in_time) { + if (offset < chunk.offset) { + use_free_chunk_if_smaller(offset, chunk.offset - offset); + } + offset = + std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_)); + } + use_free_chunk_if_smaller(offset, result_.heap_size - offset); + + if (min_fit_chunk.offset == -1) { + // Increase the heap size to fit in the last free chunk. 
+ result_.heap_size = offset + buffer_interval.size; + min_fit_chunk = {offset, buffer_interval.size}; + } + + min_fit_chunk.size = buffer_interval.size; + const auto emplace_result = + result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk); + DCHECK(emplace_result.second); + + interval_tree.Add(buffer_interval.start, buffer_interval.end, + min_fit_chunk); + } + return result_; +} + +HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { + DCHECK(!algorithms_.empty()); + std::vector results(algorithms_.size()); + int64 min_size = INT64_MAX; + int min_size_index = -1; + for (int i = 0; i < algorithms_.size(); ++i) { + results[i] = algorithms_[i]->Finish(); + if (results[i].heap_size < min_size) { + min_size = results[i].heap_size; + min_size_index = i; + } + } + + DCHECK_GE(min_size_index, 0); + return results[min_size_index]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index af05bedee72d4878f83765e5a5c5baf61bd71ba2..dbbf43082f2c1d21f5ef42f53804bf0969903a58 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,16 +21,17 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -57,7 +58,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + absl::flat_hash_map chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -88,23 +89,22 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // module_sequence), assuming no fragmentation. + // schedule), assuming no fragmentation. static StatusOr MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. 
static StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // module_sequence, which must contain a topologically-consistent total + // schedule, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -112,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run( - std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr Run(std::unique_ptr algorithm, + const HloModule& module, + const HloSchedule& schedule, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. 
The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -126,29 +126,27 @@ class HeapSimulator { static StatusOr Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); private: - // If 'module_sequence' is non-null, it is used to find kCall and kWhile + // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator( - std::unique_ptr algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, - const tensorflow::gtl::FlatMap* - memory_by_computation = nullptr); + HeapSimulator(std::unique_ptr algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, const HloSchedule* schedule = nullptr, + const absl::flat_hash_map* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation( - const HloComputation& computation, - const std::vector& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation(const HloComputation& computation, + const HloInstructionSequence& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -169,12 +167,12 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // module_sequence_ is set by buffer assignment, and 
memory_by_computation_ is + // schedule_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. - const SequentialHloOrdering::HloModuleSequence* module_sequence_; - const tensorflow::gtl::FlatMap* + const HloSchedule* schedule_; + const absl::flat_hash_map* memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of @@ -195,12 +193,12 @@ class HeapSimulator { const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + absl::flat_hash_map> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + absl::flat_hash_set allocated_buffers_; + absl::flat_hash_set freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -220,12 +218,6 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; - // NoFragmentationStatsHeap overrides this method. - virtual void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - Alloc(buffer, size); - } - // Takes memory usage of subcomputations into account when calculating the // memory usage of a computation. Currently, we don't handle buffer aliasing // between computations entirely correctly. We are careful to not double count @@ -237,7 +229,9 @@ class HeapAlgorithm { // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + // The total number of bytes allocated by instruction. 
+ int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) {} // Free de-allocates a previously allocated buffer. @@ -259,12 +253,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; - void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) override; - void AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) override; void Free(const BufferValue* buffer, int64 size) override; @@ -353,6 +344,67 @@ class LazyBestFitHeap : public HeapAlgorithm { std::set free_; }; +// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers, +// then allocates them in decreasing sizes regardless of the alloc/free time. It +// internally tracks the allocated buffers and their live intervals; when +// allocating a buffer, it finds the best-fit free chunk during its live +// interval. +class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { + public: + GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {} + ~GlobalDecreasingSizeBestFitHeap() override {} + + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; + + private: + int64 alignment_; + Result result_; + + // The current time represented as an integer. It increments by 1 at each + // Alloc or Free call. + int64 current_time_ = 0; + + // BufferInterval stores a buffer's size and time interval. + struct BufferInterval { + const BufferValue* buffer; + int64 size; + // Alloc time of the buffer. + int64 start; + // Free time of the buffer. + int64 end; + }; + absl::flat_hash_map buffer_intervals_; +}; + +// A heap algorithm that chooses the best results from other algorithms added to +// it. 
+class ChooseBestHeapAlgorithm : public HeapAlgorithm { + public: + ChooseBestHeapAlgorithm( + std::unique_ptr>> algorithms) + : algorithms_(std::move(*algorithms)) {} + ~ChooseBestHeapAlgorithm() override {} + + void Alloc(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Alloc(buffer, size); + } + } + + void Free(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Free(buffer, size); + } + } + + Result Finish() override; + + private: + std::vector> algorithms_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_ diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 5f85f145657b67634844c849447ef545a6dea468..e30e7667f3015bc7bfe67c65147a5016332780f7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -29,13 +30,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); @@ -85,12 +86,133 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) + HloSchedule schedule(module); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); +} + +TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { + // HloModule SubcomputationAccounting + + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[4]{0} constant({1, 1, 1, 1}) + // ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0} + // %constant.1) + // } + + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %slice = 
f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} + // %reshape = f32[] reshape(f32[1]{0} %slice) + // %constant = f32[] constant(0) + // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // } + + // ENTRY %SubcomputationAccounting () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, + // 3, 4 } }) %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} + // %constant.3), dimensions={0,1} %constant.2 = f32[4]{0} constant({1, 1, 1, + // 1}) %while = f32[4]{0} while(f32[4]{0} %constant.2), + // condition=%WhileCond, body=%WhileBody %broadcast = f32[2,4]{1,0} + // broadcast(f32[4]{0} %while), dimensions={1} ROOT %add = f32[2,4]{1,0} + // add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewVerifiedModule(); + const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // reshape(slice(param)) != 0 + // Needs 5 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* slice = + cond_builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1})); + HloInstruction* reshape = + cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); + HloInstruction* zero = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction* cond_comparison = + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + 
HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1}))); + HloInstruction* subtract = + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {1})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + auto entry_computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + std::vector cond_vec = {cond_param, slice, reshape, zero, + cond_comparison}; + std::vector while_body_vec = {body_param, one_vector, + subtract}; + std::vector entry_comp_vec = {while_init, while_loop, bcast, + matrix, transpose, add}; + schedule.set_sequence(cond_computation, cond_vec); + schedule.set_sequence(body_computation, while_body_vec); + schedule.set_sequence(entry_computation, entry_comp_vec); + + auto size_fn = [](const 
BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + absl::flat_hash_map memory_by_computation; + memory_by_computation[cond_computation] = 5; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -149,10 +271,11 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = + HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -168,11 +291,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. 
- SequentialHloOrdering::HloModuleSequence module_sequence; - tensorflow::gtl::FlatMap reverse_position; + HloSchedule schedule(module_.get()); + absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -185,8 +309,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } @@ -227,7 +351,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloTestBase { +class HeapSimulatorTest : public HloVerifiedTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} @@ -366,8 +490,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. 
@@ -402,8 +526,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -440,10 +564,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot1 is the output. No buffers can be shared. 
The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -481,10 +605,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); @@ -1015,5 +1139,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) { EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset); } +class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {}; + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(0, result.heap_size); + EXPECT_EQ(0, result.chunk_map.size()); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { + // space + // ^ + // | +---a---+ + // | +-------+ + // | +---c---+ + // | +-------+ + // | | b | + // | +-------+ + // | +-------+ + // | | | + // | | d | + // | +-------+ + // -----------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 30); + heap.Alloc(buffer_c_, 20); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_a_, 10); + heap.Free(buffer_b_, 30); + heap.Free(buffer_c_, 20); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(100, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(30, 
result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | | + // | | d | + // | +---a---+ +-------+ + // | + // | +-------+ + // | | | + // | | c | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 50); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 50); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(120, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | d | + // | +--a--+ +-------+ + // | +-------+ + // | | | + // | | c | + // | +-------+ + // | +-------+ + // | | | + // | | e | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 40); + heap.Free(buffer_a_, 
10); + heap.Alloc(buffer_d_, 30); + heap.Alloc(buffer_e_, 50); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 40); + heap.Free(buffer_d_, 30); + heap.Free(buffer_e_, 50); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(140, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 821c599863839865c77a778ba569c56609fea0de..dbab62f847e8ca5e0b46dfd4162a0f4222640252 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 52 +// Next ID: 58 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -124,9 +124,13 @@ message HloInstructionProto { // The string representation of the infeed configuration. bytes infeed_config = 27; - // Name of a global symbol to call, only present for kCustomCall. + // Name of an external target (e.g., global symbol) to call, only present for + // kCustomCall. string custom_call_target = 28; + // Opaque string, only present for kCustomCall. + string custom_call_opaque = 53; + // Shape of outfeed request. 
xla.Shape outfeed_shape = 29; @@ -172,7 +176,21 @@ message HloInstructionProto { xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. - xla.PrecisionConfigProto precision_config = 51; + xla.PrecisionConfig precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; + + // Sharding for kDomain instructions. + xla.OpSharding domain_entry_sharding = 54; + xla.OpSharding domain_exit_sharding = 55; + + // For custom call this indicates that the layouts are constrained. If + // constrain_layout is true then the 'shape' field must contain a layout, and + // 'operand_shapes_with_layout' must contain a shape with layout for each + // operand. + bool constrain_layout = 56; + repeated Shape operand_shapes_with_layout = 57; } // Serialization of HloComputation. @@ -196,6 +214,43 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map sequences = 1; +} + +message HloInputOutputAliasProto { + // The following proto describes an aliased pair of an input + // (described by parameter number and a ShapeIndex of the parameter) + // and an output (described by a ShapeIndex of the root + // instruction). For example: + // + // entry = { + // output_shape_index={1}, + // parameter_number=0, + // parameter_shape_index={1, 2}, + // } + // + // This entry indicates that the first parameter's {1, 2} element is + // aliased with the {1} element of the root instruction. + message AliasEntryProto { + // ShapeIndex of the root hlo. + repeated int64 output_shape_index = 1; + // Number of the parameter in entry computation. 
+ int64 parameter_number = 2; + // ShapeIndex of the parameter instruction. + repeated int64 parameter_shape_index = 3; + } + + repeated AliasEntryProto entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -206,21 +261,17 @@ message HloModuleProto { // callees appear before their callers. repeated HloComputationProto computations = 3; - // The program shape (with layout) of the entry computation. - xla.ProgramShape program_shape = 4; + // The host program shape (with layout) of the entry computation. + xla.ProgramShape host_program_shape = 4; // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; + + // Describes alias information between inputs and outputs. + HloInputOutputAliasProto input_output_alias = 8; } // Serialization of LogicalBuffer. @@ -302,6 +353,13 @@ message HeapSimulatorTrace { bool whole_module_simulation = 2; } +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + // Serialization of BufferAssignment. message BufferAssignmentProto { // Alias represents a source LogicalBuffer, and the buffer location that @@ -319,8 +377,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. 
message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0986da65cbd3d550ecfa01212364518aba651d86..cf8e6594cbe5ffd28ca75dd5006e8817f1e8581c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -57,8 +59,9 @@ class BufferValueMap { // construction process. using BufferNumber = int64; - explicit BufferValueMap(const HloDataflowAnalysis& dataflow) - : dataflow_(dataflow) { + explicit BufferValueMap(HloModule* module, + const HloDataflowAnalysis& dataflow) + : module_(module), dataflow_(dataflow) { buffers_.reserve(dataflow_.values().size()); value_to_buffer_number_.reserve(dataflow_.values().size()); for (const HloValue* value : dataflow_.values()) { @@ -119,7 +122,7 @@ class BufferValueMap { } // Return a set of all the values in the given buffer. - const tensorflow::gtl::FlatSet& GetValuesInBuffer( + const absl::flat_hash_set& GetValuesInBuffer( BufferNumber buffer_number) const { return buffers_.at(buffer_number); } @@ -142,7 +145,7 @@ class BufferValueMap { // Move the given value into the given buffer. 
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - tensorflow::gtl::FlatSet& old_value_set = + absl::flat_hash_set& old_value_set = buffers_.at(old_buffer_number); old_value_set.erase(&value); if (old_value_set.empty()) { @@ -169,6 +172,42 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } + void ComputeInputOutputAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + // Get parameter value from an aliased_input object. + const auto get_parameter_value = + [this](const std::pair& aliased_input) + -> const HloValue& { + int64 param_number = aliased_input.first; + const ShapeIndex& param_index = aliased_input.second; + return dataflow_.GetUniqueValueAt( + module_->entry_computation()->parameter_instruction(param_number), + param_index); + }; + + // If the value shows up in a root instruction, alias it with parameter + // instruction. + for (const HloPosition& pos : value.positions()) { + if (pos.instruction == module_->entry_computation()->root_instruction()) { + ShapeIndex output_index = pos.index; + + auto aliased_input = + module_->input_output_alias_config().GetAliasedParameter( + output_index); + if (aliased_input) { + aliased_buffers->push_back( + GetBufferForValue(get_parameter_value(*aliased_input))); + } + } + } + + // If the value is parameter instruction itself, alias it with itself. 
+ if (value.instruction()->opcode() == HloOpcode::kParameter && + value.instruction()->parent() == module_->entry_computation()) { + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + void ComputeWhileAliasedBuffers(const HloValue& value, std::vector* aliased_buffers) { VLOG(3) << "Compute kWhile aliases"; @@ -276,6 +315,7 @@ class BufferValueMap { VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; } std::vector aliased_buffers; + ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. @@ -286,17 +326,17 @@ class BufferValueMap { return aliased_buffers; } + HloModule* module_; + // Dataflow analysis used to construct the buffer map. const HloDataflowAnalysis& dataflow_; // A map containing the set of values contained in each buffer. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> buffers_; // A map indicating which buffer each value is contained in. - tensorflow::gtl::FlatMap - value_to_buffer_number_; + absl::flat_hash_map value_to_buffer_number_; // The buffer number of the next buffer to be created. 
BufferNumber next_buffer_number_ = 0; @@ -352,7 +392,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous( bool HloAliasAnalysis::InstructionBuffersAreDistinct( const HloInstruction* instruction) const { - tensorflow::gtl::FlatSet buffers_seen; + absl::flat_hash_set buffers_seen; for (const auto& pair : dataflow_analysis_->GetInstructionValueSet(instruction)) { const HloValueSet& value_set = pair.second; @@ -461,7 +501,7 @@ StatusOr> HloAliasAnalysis::Run( /*bitcast_defines_value=*/false, fusion_can_share_buffer)); - BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); + BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); // Create a vector of HloBuffers, one for each set of values in the diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 1fea544730c27efdaa260f55ea81c163165f7ed5..372f99ff01c786a503e9fc2a1ba96fb4abf75b4c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -110,7 +111,7 @@ class HloAliasAnalysis { std::unique_ptr dataflow_analysis_; // A map indicating which buffer a value is contained in. 
- tensorflow::gtl::FlatMap value_to_buffer_; + absl::flat_hash_map value_to_buffer_; // A lazily constructed vector containing all HloBuffers sorted by // HloBuffer::Id. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index da94ab5346e5628b4a603b3ac2d84071904d1e65..5c8d97b2d15e15d15cb8014a7d25b37437ce8aec 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,15 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloTestBase { +class HloAliasAnalysisTest : public HloVerifiedTestBase { protected: - HloAliasAnalysisTest() : module_(CreateNewModule()) {} + HloAliasAnalysisTest() : HloVerifiedTestBase() { + module_ = CreateNewModule(); + } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_.get(), + analysis_ = HloAliasAnalysis::Run(module_, /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -91,7 +93,7 @@ class HloAliasAnalysisTest : public HloTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. 
bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -108,7 +110,7 @@ class HloAliasAnalysisTest : public HloTestBase { return false; } - std::unique_ptr module_; + HloModule* module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -215,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } +TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + // Cannot alias an output twice. 
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { + // parameter 0 aliased with output 1 and parameter 1 aliased with output 0. + // + // (p0 , p1) + // \ / + // \ / + // alias X + // / \ + // / \ + // (p0 , p1) + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Every Ops in this graph are aliased with each other. 
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { + // Test a simple single while instruction can be aliased with input and output + // of the computation. + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %param1 = param1 + // %while = While(%param1, body, condition) + // %while_1 = GTE(%while, 0) + // %while_2 = GTE(%while, 1) + // %negate_1 = Negate(%while_1) + // %negate_2 = Negate(%while_2) + // return Tuple(negate_1, negate_2) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". 
+ auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, param)); + auto while_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0)); + auto while_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1)); + auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_1)); + auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_2)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), + UnorderedElementsAre(GetValueDefinedAt(param, {1}), + GetValueDefinedAt(xla_while, /*index=*/{1}), + GetValueDefinedAt(body_param, {1}), + GetValueDefinedAt(cond_param, {1}), + GetValueDefinedAt(add), + GetValueDefinedAt(negate_2))); + + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), + 
UnorderedElementsAre( + HloPosition{param, {1}}, HloPosition{xla_while, {1}}, + HloPosition{while_element_2, {}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}}, + HloPosition{cond_param, {1}}, HloPosition{negate_2, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); +} + TEST_F(HloAliasAnalysisTest, SingleCall) { // Test a single call of a subcomputation. The subcomputation adds its two // array-shaped parameters. @@ -461,7 +638,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_.get()).status()); + TF_ASSERT_OK(flattener.Run(module_).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -835,7 +1012,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -877,24 +1054,26 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } // For a sequential order, if there is interference iff the negate is after // the while. 
- SequentialHloOrdering::HloModuleSequence sequence; - sequence[body] = {body_param, body_root}; - sequence[condition] = {cond_param, cond_root}; + HloSchedule schedule(module_); + schedule.set_sequence(body, {body_param, body_root}); + schedule.set_sequence(condition, {cond_param, cond_root}); { - sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_.get(), sequence); + schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_.get(), sequence); + schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 6c11a073b74c61e44dfe81a32261ae78ae7b46fb..9c3aa0e64d119c2560f4955d0bcb492519fa52a2 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index 4873463b2ea4fee3ee39dff31fc3429a4998142f..a88c87e46c8100571aff24f70a2a19fe8ce71ebc 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -84,7 +84,7 @@ class HloBuffer { return a->id() == b->id(); } - HloBuffer(Id id, tensorflow::gtl::ArraySlice values) + HloBuffer(Id id, absl::Span values) : id_(id), values_(values.begin(), values.end()) {} // Return the unique identifier for this HloBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h index 658643b427a9625fac1166151a89cbd669f817d5..24910ca07bf7c991d31875704b5dd918ed04fe6f 100644 --- a/tensorflow/compiler/xla/service/hlo_clone_context.h +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -18,8 +18,8 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -73,12 +73,12 @@ class HloCloneContext { return FindOrDie(computations_, old_computation); } - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& cloned_instructions() const { return instructions_; } - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& cloned_computations() const { return computations_; } @@ -86,10 +86,8 @@ class HloCloneContext { private: HloModule* module_; string suffix_; - tensorflow::gtl::FlatMap - instructions_; - tensorflow::gtl::FlatMap - computations_; + absl::flat_hash_map instructions_; + absl::flat_hash_map computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index cf95b112d7c69b4f098f703699bb2a418d380801..b0f7cd91ad1db0a59c09cfbfc1885813dc57e01e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -24,6 +24,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -39,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -122,30 +123,6 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } -namespace { - -// Returns the new name for a fusion parameter when we change its number. -// -// Fusion parameters are named foo.param_1, bar.param_2, etc. 
We are -// renumbering the parameters, so replace the final number in the name with -// the updated value. -string RenameFusionParameter(const string& original_name, int64 new_param_no) { - const string param_underscore = ".param_"; - size_t index = original_name.rfind(param_underscore); - if (index == string::npos) { - return original_name; - } - string after_param = original_name.substr(index + param_underscore.size()); - int64 numeric_suffix; - if (absl::SimpleAtoi(after_param, &numeric_suffix)) { - return StrCat(original_name.substr(0, index + param_underscore.size()), - new_param_no); - } - return original_name; -} - -} // namespace - Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -158,11 +135,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + param_no, param_instruction->shape(), StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -186,11 +161,9 @@ Status HloComputation::RemoveUnusedParameters() { if (removed > 0) { const int64 param_no = i - removed; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); - HloInstruction* new_instr = - AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + HloInstruction* new_instr = AddInstructionInternal( + HloInstruction::CreateParameter(param_no, param_instruction->shape(), + StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); 
param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -242,7 +215,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( if (removed.count(item) != 0 || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || - item->HasSideEffect()) { + (item->HasSideEffect() && item != instruction)) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -272,18 +245,19 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { << "instruction " << instruction->name() << " has control successors and cannot be removed"; - TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); - auto inst_it = instruction_iterators_.at(instruction); - (*inst_it)->set_parent(nullptr); - instructions_.erase(inst_it); + auto inst_it = instruction_iterators_.find(instruction); + TF_RET_CHECK(inst_it != instruction_iterators_.end()); + (*inst_it->second)->set_parent(nullptr); + instructions_.erase(inst_it->second); + instruction_iterators_.erase(inst_it); return Status::OK(); } -void HloComputation::set_root_instruction( - HloInstruction* new_root_instruction) { +void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!IsFusionComputation()) { + if (!IsFusionComputation() && !accept_different_shape) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape() << " is incompatible with " @@ -304,10 +278,9 @@ void HloComputation::set_root_instruction( namespace { // Helper which builds a post order of the HLO call graph. 
-void ComputeComputationPostOrder( - HloComputation* computation, - tensorflow::gtl::FlatSet* visited, - std::vector* post_order) { +void ComputeComputationPostOrder(HloComputation* computation, + absl::flat_hash_set* visited, + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -319,12 +292,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( - std::map> channel_dependency_map, +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) { + absl::flat_hash_map* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -362,20 +335,22 @@ void ComputeInstructionPostOrder( // dependencies. switch (current->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = - channel_dependency_map[current->channel_id()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = current->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } } break; @@ -386,11 +361,9 @@ void ComputeInstructionPostOrder( } } -} // namespace - -std::map> +HloComputation::ChannelDependencyMap 
HloComputation::ComputeChannelDependencies() const { - std::map> channel_dependency_map; + ChannelDependencyMap channel_dependency_map; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { case HloOpcode::kSend: { @@ -421,7 +394,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + absl::flat_hash_map visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -442,7 +415,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector HloComputation::MakeEmbeddedComputationsList() const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; std::vector post_order; // To avoid special handling of this computation, cast away const of @@ -464,6 +437,14 @@ std::vector HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -486,7 +467,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } @@ -522,9 +505,9 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> 
HloComputation::CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map) { - tensorflow::gtl::FlatMap instruction_map; - tensorflow::gtl::FlatMap to_proto_id; + const absl::flat_hash_map& computation_map) { + absl::flat_hash_map instruction_map; + absl::flat_hash_map to_proto_id; std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { @@ -552,13 +535,37 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + TF_RETURN_IF_ERROR([&]() -> Status { + std::vector parameters_seen(parameter_count); + int parameters_seen_count = 0; + for (auto& instruction : instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_no = instruction->parameter_number(); + TF_RET_CHECK(param_no >= 0 && param_no < parameter_count) + << "Invalid parameter number. 
Expected [0, " << parameter_count + << "), got " << param_no; + TF_RET_CHECK(!parameters_seen[param_no]) + << "Parameter number " << param_no + << " already allocated in this computation"; + parameters_seen[param_no] = true; + parameters_seen_count++; + } + } + TF_RET_CHECK(parameters_seen_count == parameter_count) + << "Not all parameters in range [0, " << parameter_count + << ") were referenced"; + return Status::OK(); + }()); + + auto computation = absl::WrapUnique( + new HloComputation(proto.name(), parameter_count, &instructions, root, + /*fusion_instruction=*/nullptr)); + computation->unique_id_ = proto.id(); + return std::move(computation); } void HloComputation::FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction) { CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); HloInstruction* root = instructions_to_fuse.front(); @@ -577,7 +584,7 @@ void HloComputation::FuseInstructionsInto( } HloInstruction* HloComputation::CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind) { HloInstruction* root = instructions_to_fuse.front(); HloInstruction* fusion_instruction = AddInstruction( @@ -625,16 +632,15 @@ StatusOr HloComputation::DeepCopyInstruction( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " "has incompatible shapes: %s vs. 
%s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanString(indices_to_copy->shape())); } ShapeIndex index; @@ -664,7 +670,7 @@ StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } ShapeIndex index; return DeepCopyHelper(instruction, &index, copy_leaf); @@ -747,16 +753,19 @@ std::unique_ptr HloComputation::ComputeReachability() switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = channel_dependency_map[hlo->channel_id()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } } break; } @@ -902,13 +911,14 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - context, suffix); + /*extras=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloCloneContext* context, const string& suffix) { + absl::Span extras, HloCloneContext* context, + const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { context_ptr = 
absl::make_unique(parent(), suffix); @@ -930,6 +940,9 @@ std::unique_ptr HloComputation::CloneWithReplacements( VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; std::vector postorder; + for (HloInstruction* instr : extras) { + postorder.push_back(instr); + } for (HloInstruction* instr : MakeInstructionPostOrder()) { if (HloInstruction* replacement = replace(instr)) { postorder.push_back(replacement); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 8d9b69497737312280a8d3c421e1f20ee346051c..dec96d11a93cf56d3c40a6bb7882ffb7336aeeb0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -25,6 +25,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -39,9 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -128,15 +128,18 @@ class HloComputation { // users. Instruction is deallocated with this call. Status RemoveInstruction(HloInstruction* instruction); - // Remove an instruction from the computation and also transitively any - // operand that has no users post removing an instruction. The instruction - // must have no users. Instruction is deallocated with this call. 
+ // Remove an instruction (including side effecting ones) from the computation + // and also transitively any operand that has no side effect and no users post + // removing an instruction. The instruction must have no users. Instruction is + // deallocated with this call. Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); // Set the root of the computation to the given instruction. The instruction - // must have already been added to the computation and have the same shape as - // the result of the computation for non fusion computations. - void set_root_instruction(HloInstruction* new_root_instruction); + // must have already been added to the computation. In addition it must have + // the same shape as the result of the computation for non fusion + // computations, except if accept_different_shape is set to true. + void set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape = false); // Return the root instruction of the computation. The root instruction is the // instruction which produces the output of the computation. @@ -170,6 +173,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. + string ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const; + // Returns a serialized representation of this computation. HloComputationProto ToProto() const; @@ -181,7 +189,7 @@ class HloComputation { // calls. static StatusOr> CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map); + const absl::flat_hash_map& computation_map); // Gets the instructions in this computation. 
// @@ -220,7 +228,7 @@ class HloComputation { void UpdateReachabilityThroughInstruction( const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instructions_.size(); } + int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this // computation. This includes all embedded computations called directly or @@ -237,7 +245,7 @@ class HloComputation { // removed if they have no uses after fusion (this is necessarily true for at // least the root). HloInstruction* CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind); // Create a deep copy of the given instruction and return the instruction @@ -326,10 +334,13 @@ class HloComputation { // // If replacements maps a key to nullptr, we remove that instruction from the // new computation. + // If additional instructions are used by instructions in replacement map, + // they must be passed in post-order in the extras span. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloCloneContext* context = nullptr, const string& suffix = "clone"); + absl::Span extras, HloCloneContext* context = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -385,7 +396,7 @@ class HloComputation { // // Pre-condition: fusion_instruction's opcode is kFusion. void FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction); // Internal helper for recursive copying of an instruction. Creates and @@ -403,8 +414,15 @@ class HloComputation { // instructions. 
For send&recv pairs it means the send instruction and for // cross-replica-sum the union of the dependencies for all participating // instructions. - std::map> ComputeChannelDependencies() - const; + using ChannelDependencyMap = + absl::flat_hash_map>; + ChannelDependencyMap ComputeChannelDependencies() const; + + enum VisitState { kVisiting, kVisited }; + void ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, + std::vector* post_order, HloInstruction* root, + absl::flat_hash_map* visited) const; string name_; int64 unique_id_; @@ -422,7 +440,7 @@ class HloComputation { // instruction pointer to location in the list for fast lookup. using InstructionList = std::list>; InstructionList instructions_; - std::unordered_map + absl::flat_hash_map instruction_iterators_; std::vector param_instructions_; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index f7ed1b0316b213a0f34b1d690229f0173dbd5250..2aaaef1d36d58bcce18db4aa37ff05ea352e484b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig 
precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 2ed645c3aed525dea05604eefa24d49b54f8a5db..4f898ce61c3f36e83e4b13130a404dbb4a2c36c6 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -71,18 +71,40 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. - if (instruction->opcode() == HloOpcode::kBroadcast) { + if (instruction->opcode() == HloOpcode::kBroadcast || + instruction->opcode() == HloOpcode::kIota) { continue; } - std::unique_ptr result = evaluator->TryEvaluate(instruction); + // Don't constant fold unless it's a net positive or the output is small. 
+ if (ShapeUtil::IsArray(instruction->shape())) { + int64 elements_in_removed_operands = 0; + for (HloInstruction* operand : instruction->operands()) { + if (operand->user_count() == 1 && + ShapeUtil::IsArray(operand->shape())) { + elements_in_removed_operands += + ShapeUtil::ElementsIn(operand->shape()); + } + } + int64 elements_in_constant = + ShapeUtil::ElementsIn(instruction->shape()); + + static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000; + if (elements_in_constant > elements_in_removed_operands && + elements_in_constant > kMaximumConstantSizeElements) { + continue; + } + } + + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. - if (result == nullptr) { + if (!evaluator->TryEvaluate(instruction, &result)) { VLOG(2) << "Constant folding failed for instruction: " << instruction->ToString(); continue; } + VLOG(4) << "Constant folded: " << instruction->ToString(); TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant(std::move(result)))); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 4557983a9c0b0006cc2189c96a88478d469475c1..4a624cc7b8483aaa834634185a23195e437bd4e4 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -23,7 +23,7 @@ namespace xla { // A pass which performs constant folding in order to avoid unnecessary // computation on constants. 
-class HloConstantFolding : public HloPassInterface { +class HloConstantFolding : public HloModulePass { public: absl::string_view name() const override { return "constant_folding"; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 7cd1481a8ad72f5a7ae6536621572ba537a103de..e45f905f7152c37a9ab2b41d407310671310c2a3 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloVerifiedTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); 
EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -105,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { TEST_F(HloConstantFoldingTest, Concatenate) { const struct TestConfig { int concat_dimension; - tensorflow::gtl::ArraySlice dimensions; - tensorflow::gtl::ArraySlice concat_sizes; + absl::Span dimensions; + absl::Span concat_sizes; } test_configs[] = { {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, @@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( 
HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -196,9 +196,9 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; root->literal().EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get(rindexes)); + matched = matched && (value == literal_clone.Get(rindexes)); }); EXPECT_TRUE(matched); } @@ -219,28 +219,47 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(kConstantFoldReduce)); + ParseAndVerifyModule(kConstantFoldReduce); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_TRUE(result); - EXPECT_EQ(6, module->entry_computation() + EXPECT_EQ(6, module() + .entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(kConstantFoldReduce)); - HloInstruction* add = module->computations().begin()->root_instruction(); + ParseAndVerifyModule(kConstantFoldReduce); + HloInstruction* add = module().computations().begin()->root_instruction(); 
LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + EXPECT_FALSE(result); + + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); +} + +const char* const kConstantFoldLargePad = R"( + HloModule ConstantFoldLargePad + + ENTRY r { + a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + b = f32[] constant(42) + ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 + })"; + +TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kConstantFoldLargePad)); + HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_FALSE(result); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(op::Constant(), op::Constant())); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 5add4251ef73286285e525ec41ce43ecaea28641..23ab4cda93fc5d6979308bdf9a87f0a16d465154 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { + // Domain does not have any computation or data transfer. 
+ current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); @@ -274,15 +282,21 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) { } Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { - auto arg = reduce->operand(0); HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. - int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - - ShapeUtil::ElementsIn(reduce->shape()); + // This counts the number of times the reduction function is applied, so it + // does not need to be multiplied by the number of input tensors - that's + // already "priced in" by the sub-computation doing more work. + auto arg = reduce->operand(0); + auto output_shape = ShapeUtil::IsArray(reduce->shape()) + ? 
reduce->shape() + : reduce->shape().tuple_shapes(0); + int64 reduction_count = + ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * reduction_count; @@ -501,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { valid_position_counts.push_back(valid_position_count); } - const int64 fma_count = - input_feature * output_feature * batch * Product(valid_position_counts); + const int64 fma_count = (input_feature / convolution->feature_group_count()) * + output_feature * batch * + Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } @@ -543,6 +558,10 @@ Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { return Status::OK(); } +Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume @@ -645,6 +664,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { } Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { + // Gather doesn't read the whole input buffer, it's equivalent to a copy the + // size of the output shape and a read of the gather indices. + current_properties_[kBytesAccessedKey] = + GetShapeSize(gather->shape()) * 2 + + GetShapeSize(gather->operand(1)->shape()); // Gather does not issue any flops. 
return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 1bf1c4a315655e78e10a8a66b571347357cc23e9..46b4bbeef222e6de581360fc01b293e812f1dedd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -67,11 +67,13 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; + Status HandleDomain(const HloInstruction* domain) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; diff --git 
a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18642eb7cb081b4fdfe3bc83627e41ae..802cdfc9e454cf05db18fad9bc7f44fdc146a92e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) { sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } +TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { + XlaBuilder builder("convolution"); + auto input = Parameter( + &builder, 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = Parameter( + &builder, 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x120x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. + EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. 
+ EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18)); +} + TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = @@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( @@ -503,5 +556,30 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { EXPECT_EQ(analysis.bytes_accessed(), 8); } +TEST_F(HloCostAnalysisTest, Gather) { + // Test the analysis on a gather. 
+ XlaBuilder builder("gather"); + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); + dim_numbers.set_index_vector_dim(1); + Gather(operand, indices, dim_numbers, {1, 3}); + + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 56); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 0ceb6a29685aed5b9b8bbc25968a00a3c5b56118..b2005d3c210d4ae7e3702cb9624c3ad98056984c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" namespace xla { using absl::StrCat; -using tensorflow::gtl::ArraySlice; StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -50,9 +50,9 @@ StatusOr MakePadHlo(HloInstruction* operand, } StatusOr MakeSliceHlo(HloInstruction* operand, - ArraySlice start_indices, - ArraySlice limit_indices, - ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( operand->shape(), start_indices, @@ -62,19 +62,22 @@ StatusOr MakeSliceHlo(HloInstruction* operand, } StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), - window, dimension_numbers)); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), feature_group_count, + window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, window, dimension_numbers)); + convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, - ArraySlice dimensions) { + absl::Span dimensions) 
{ HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( Shape transpose_shape, @@ -91,15 +94,15 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, } StatusOr MakeReshapeHlo( - ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + absl::Span result_shape_dim_bounds, HloInstruction* operand) { Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_dim_bounds); return MakeReshapeHlo(new_shape, operand); } -StatusOr MakeDynamicSliceHlo(HloInstruction* operand, - HloInstruction* start_indices, - ArraySlice slice_sizes) { +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); TF_ASSIGN_OR_RETURN( @@ -125,8 +128,8 @@ StatusOr MakeDynamicUpdateSliceHlo( } StatusOr MakeBroadcastHlo( - HloInstruction* operand, ArraySlice broadcast_dimensions, - ArraySlice result_shape_bounds) { + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -146,8 +149,8 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); } -StatusOr MakeConcatHlo(ArraySlice operands, - int64 dimension) { +StatusOr MakeConcatHlo( + absl::Span operands, int64 dimension) { CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); @@ -166,19 +169,19 @@ StatusOr MakeConcatHlo(ArraySlice operands, } StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers) { + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( Shape dot_shape, 
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); - return computation->AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); + return computation->AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dim_numbers, precision_config)); } -StatusOr MakeMapHlo( - tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) { +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation) { CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; HloComputation* computation = operands.front()->parent(); std::vector operand_shapes; @@ -199,6 +202,44 @@ StatusOr MakeMapHlo( HloInstruction::CreateMap(map_shape, operands, map_computation)); } +StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module) { + DCHECK_NE(nullptr, module); + std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::iota(all_dims.begin(), all_dims.end(), 0); + + auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); + HloComputation* reduce_computation; + { + HloComputation::Builder b(operand->name() + ".reduce_sub_computation"); + auto lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + b.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); + reduce_computation = module->AddEmbeddedComputation(b.Build()); + } + + return operand->parent()->AddInstruction(HloInstruction::CreateReduce( + scalar_shape, operand, init_value, all_dims, reduce_computation)); +} + +StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + HloComputation* computation = pred->parent(); + DCHECK_EQ(computation, on_true->parent()); + DCHECK_EQ(computation, on_false->parent()); + TF_ASSIGN_OR_RETURN(Shape select_shape, + 
ShapeInference::InferTernaryOpShape( + HloOpcode::kSelect, pred, on_true, on_false)); + return computation->AddInstruction(HloInstruction::CreateTernary( + select_shape, HloOpcode::kSelect, pred, on_true, on_false)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -235,7 +276,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, } StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, ArraySlice expanded_dims) { + HloInstruction* operand, absl::Span expanded_dims) { CHECK_GT(operand->shape().dimensions_size(), 0); CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); @@ -251,8 +292,8 @@ StatusOr ExpandFirstDimIntoNDims( return MakeReshapeHlo(new_shape, operand); } -StatusOr ElideDegenerateDims(HloInstruction* operand, - ArraySlice dims_to_elide) { +StatusOr ElideDegenerateDims( + HloInstruction* operand, absl::Span dims_to_elide) { CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); @@ -277,7 +318,7 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, } StatusOr InsertDegenerateDims( - HloInstruction* operand, ArraySlice dims_to_insert) { + HloInstruction* operand, absl::Span dims_to_insert) { CHECK(absl::c_is_sorted(dims_to_insert)); const Shape& operand_shape = operand->shape(); @@ -319,26 +360,25 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(operand->shape().element_type())))); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(operand->shape().element_type()))); return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - ArraySlice broadcast_dimensions) { - HloInstruction* zero = - 
computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + absl::Span broadcast_dimensions) { + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } StatusOr> CreateComputationWithSignature( - ArraySlice domain, const Shape& range, + absl::Span domain, const Shape& range, absl::string_view name) { - HloComputation::Builder b{std::string(name)}; + HloComputation::Builder b{string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 1bc6d09b4502c88d0d4e4e207075d64714190611..8e5ddbbd503a501bd493aec43a2ccd4db883ef0c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -40,21 +41,22 @@ StatusOr MakePadHlo(HloInstruction* operand, // Creates a slice HLO instruction and adds it to the computation containing // `operand`. 
-StatusOr MakeSliceHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +StatusOr MakeSliceHlo(HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeTransposeHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); +StatusOr MakeTransposeHlo(HloInstruction* operand, + absl::Span dimensions); // Creates a reshape HLO instruction and adds it to the computation containing // `operand`. @@ -62,15 +64,14 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, HloInstruction* operand); StatusOr MakeReshapeHlo( - tensorflow::gtl::ArraySlice result_shape_dim_bounds, - HloInstruction* operand); + absl::Span result_shape_dim_bounds, HloInstruction* operand); // Creates a dynamic-slice HLO instruction and adds it to the computation // containing `operand` and `start_indices` (`operand` and `start_indices` must // be in the same computation). 
StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Creates a dynamic-update-slice HLO instruction and adds it to the computation // containing `operand`, `update` and `start_indices` (`operand`, `update` and @@ -82,9 +83,8 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. StatusOr MakeBroadcastHlo( - HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions, - tensorflow::gtl::ArraySlice result_shape_bounds); + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -95,18 +95,47 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, // containing `operands` (`operands` must be non-empty and every element must be // contained in the same computation). StatusOr MakeConcatHlo( - tensorflow::gtl::ArraySlice operands, int64 dimension); + absl::Span operands, int64 dimension); // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers); + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. -StatusOr MakeMapHlo( - tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation); + +// Creates a Reduce HLO instruction and adds it to the computation containing +// the operand. This will create the sub-computation needed for the reduction in +// the given module. 
binary_opcode should represent a binary operation. +StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module); + +// Creates a Select HLO instruction and adds it to the computation containing +// the predicate. The on_true and on_false instructions must also be contained +// in the same computation. +StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false); + +// Creates an R1 Constant HLO instruction of the given PrimitiveType with the +// given values and adds it to the given computation. +template +StatusOr MakeR1ConstantHlo(HloComputation* computation, + PrimitiveType type, + absl::Span values) { + Literal literal = LiteralUtil::CreateR1(values); + if (literal.shape().element_type() != type) { + TF_ASSIGN_OR_RETURN(literal, literal.Convert(type)); + } + return computation->AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); +} // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of @@ -138,7 +167,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, // For instance if `operand` has shape f32[200,9,7] and expanded_dims is // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + HloInstruction* operand, absl::Span expanded_dims); // Elides (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_elide` from `operand`. Every dimension in @@ -148,7 +177,7 @@ StatusOr ExpandFirstDimIntoNDims( // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. 
StatusOr ElideDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + HloInstruction* operand, absl::Span dims_to_elide); // Inserts (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_insert` into `operand`. The dimensions in @@ -158,7 +187,7 @@ StatusOr ElideDegenerateDims( // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. StatusOr InsertDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_insert); + HloInstruction* operand, absl::Span dims_to_insert); // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. @@ -171,12 +200,12 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // broadcast instruction is emitted into `computation`. StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. StatusOr> CreateComputationWithSignature( - tensorflow::gtl::ArraySlice domain, const Shape& range, + absl::Span domain, const Shape& range, absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index a8de285d16fdf6c5824f4076860b57b3fdc279a0..e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,18 +19,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -using tensorflow::gtl::ArraySlice; -class HloCreationUtilsTest : public HloTestBase { +class HloCreationUtilsTest : public HloVerifiedTestBase { protected: - std::unique_ptr CreateModuleWithProgramShape( - PrimitiveType primitive_type, ArraySlice input_shape_dims, - ArraySlice output_shape_dims, HloInstruction** param, + HloModule* CreateModuleWithProgramShape( + PrimitiveType primitive_type, absl::Span input_shape_dims, + absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = @@ -48,27 +47,27 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - 
std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, &entry_computation); @@ -79,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2( + CHECK_EQ(result_literal, + LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -93,10 +92,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -104,17 +103,17 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, &entry_computation); @@ -125,37 +124,37 
@@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(9)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, &entry_computation); @@ -166,21 +165,21 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 
5, 6})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -188,20 +187,20 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -209,20 +208,20 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(0)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + 
evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - F32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(F32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -230,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR0(0.0f)})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index cb367adf5ef29111838dd6ee1b770394eef1301c..e602107cbe64320a8e8e740168cb294ec6be9667 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -33,8 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -137,8 +137,8 @@ StatusOr HloCSE::Run(HloModule* module) { // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. - tensorflow::gtl::FlatSet + absl::flat_hash_set representatives(/*N=*/computation->instruction_count() + 1, &CseHash, cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index a28c03599a8765da708f37b986010713654647cb..e4857fd3fdd9a329b013ac8215cb6d36d73c4b7d 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -25,7 +25,7 @@ namespace xla { // and identical instructions with the same operands are commoned. The pass // iterates over the instructions in topological order which enables the pass to // find arbitrarily large common expressions. -class HloCSE : public HloPassInterface { +class HloCSE : public HloModulePass { public: // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 406d712ec6783a310aabc6600b8b70e1a1ae30a9..9b18b0284f63c25934c1b7118dc8973caa62cadc 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloTestBase { +class HloCseTest : public HloVerifiedTestBase { protected: HloCseTest() {} }; @@ -65,15 +65,15 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get({})); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0(84.0); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -96,16 +96,16 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); EXPECT_THAT(add, op::Add(first_operand, first_operand)); - auto result = ExecuteAndTransfer(std::move(module), {}); + 
auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -128,14 +128,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { @@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] 
%get-tuple-element.2, f32[] condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })") - .ValueOrDie(); + })"); - auto computation = 
module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); 
uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule m add_computation { @@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })") - .ValueOrDie(); + })"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -735,13 +730,11 @@ ENTRY %entry { domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})") - .ValueOrDie(); +})"); HloCSE cse(/*is_layout_sensitive=*/false); - 
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - LOG(INFO) << "AAAAA " << module->ToString(); - const HloInstruction* sub = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + const HloInstruction* sub = module().entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 1d35757b424bba1e175e7006593b0026527eb62b..5dcf6bc985ff18fa6fc1ab5a5692914b4597d065 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -46,8 +47,7 @@ namespace { // // In this case, we should be able to reuse p0 and output, although p0 has // multiple uses. 
-bool MultiDynamicSliceUseShareSameIndices( - tensorflow::gtl::ArraySlice uses) { +bool MultiDynamicSliceUseShareSameIndices(absl::Span uses) { if (uses.empty()) { return false; } @@ -92,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; absl::InlinedVector stack; stack.push_back(inst); while (!stack.empty()) { @@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const HloValue& HloDataflowAnalysis::GetValueDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const { - CHECK(ValueIsDefinedAt(instruction, index)); + CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString(); return GetUniqueValueAt(instruction, index); } @@ -160,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { void HloDataflowAnalysis::DeleteMarkedValues() { #ifndef NDEBUG // Verify that no marked-for-deletion values are in any of the value sets. 
- tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), - value_ids_to_delete_.end()); + absl::flat_hash_set id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); for (const auto& pair : value_sets_) { const HloInstruction* instruction = pair.first; const InstructionValueSet& instruction_value_set = pair.second; @@ -221,7 +221,7 @@ string HloDataflowAnalysis::ToString() const { bool HloDataflowAnalysis::Phi( HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; VLOG(5) << "instruction value set = " @@ -356,23 +356,6 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } -bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { - CHECK_EQ(slice->opcode(), HloOpcode::kSlice); - if (!slice->IsInPlaceSlice()) { - return false; - } - // If this slice is lowered to an in-place version, then it forwards the - // operand value to the output. 
- const InstructionValueSet& operand_set = - GetInstructionValueSet(slice->operand(0)); - InstructionValueSet& slice_set = GetInstructionValueSet(slice); - if (operand_set != slice_set) { - slice_set = operand_set; - return true; - } - return false; -} - bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -641,8 +624,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); - case HloOpcode::kSlice: - return UpdateSliceValueSet(instruction); case HloOpcode::kDomain: return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: @@ -674,7 +655,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue worklist; - tensorflow::gtl::FlatSet workset; + absl::flat_hash_set workset; auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { worklist.push(instruction); @@ -814,11 +795,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; - case HloOpcode::kSlice: - if (!instruction->IsInPlaceSlice()) { - define_all_values(); - } - break; case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: @@ -837,7 +813,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Unimplemented( "Computation %s is called in both a parallel (eg, kMap) and " "sequential (eg, kCall) context", - computation->name().c_str()); + computation->name()); } if (call_graph_node.caller_callsites().empty() || call_graph_node.context() == CallContext::kParallel) { @@ -1072,6 +1048,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in 
BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index a1678d4943c7c722df38c4dc93e284d614279217..abac398c04fc4c418d8814a0097db4434bc1cd9c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -182,7 +182,6 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); - bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); @@ -202,7 +201,7 @@ class HloDataflowAnalysis { // the given instruction. If skip_top_level is true, then the top level of the // value set of 'instruction' is not modified. bool Phi(HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs); + absl::Span inputs); // Updates the positions of the HloValues in the output of the given // instruction. 
This should be called after the instruction value set of diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index d1a96c10f88e3c05e21a6db4eccb46683cd64c4a..909853106d57d181e85e3e4134b4039be2b176f5 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param, xla_while}}); - sequence.insert({condition, {cond_param, cond_constant}}); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, xla_while}); + schedule.set_sequence(condition, {cond_param, cond_constant}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. 
- sequence.insert({body, {constant, exp, body_param, add}}); + schedule.set_sequence( + body, {constant, exp, body_param, add, dead_constant, dead_negate}); + TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(schedule); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. @@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - std::vector order = {param, negate, exp, add}; - sequence.emplace(entry, order); - - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, negate, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); @@ -2280,6 +2283,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* operand_param = 
computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -2305,7 +2346,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); @@ -2334,8 +2376,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 1fe69b1395753a612499e6e87bfc22f8ac8e767b..401204267282b294ca9f701e29e9edd9f0f35b98 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ 
b/tensorflow/compiler/xla/service/hlo_dce.h @@ -33,7 +33,7 @@ namespace xla { // // This pass does not remove dead parameter instructions, as parameter // instructions cannot be deleted. -class HloDCE : public HloPassInterface { +class HloDCE : public HloModulePass { public: ~HloDCE() override {} absl::string_view name() const override { return "dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index d36631fc2f16902ed8f1f89f903027081f9b3801..c0bf1b9e16b52d81365db277abeb06defeb12d44 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -30,7 +30,7 @@ namespace xla { // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no // kDomain instruction will be placed. -class HloDomainIsolator : public HloPassInterface { +class HloDomainIsolator : public HloModulePass { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index edf0073f3091ef4da7ced3f13b56961a7db4b430..c6d02f9f67bb599e496d20fc2acf2e627ed54438 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -40,17 +42,22 @@ namespace xla { return std::move(domain_map); } -bool HloDomainMap::InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const { +bool HloDomainMap::InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const { int64 domain_id1 = GetDomainId(instruction1); int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } -int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } +int64 HloDomainMap::GetDomainMetadataId( + const HloInstruction* instruction) const { + return FindOrDie(domain_metadata_id_, instruction); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -72,6 +79,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { } Status HloDomainMap::Populate(HloComputation* computation) { + InstructionOrderMap instructions_post_order; + int64 count = 0; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + instructions_post_order.insert(std::make_pair(instruction, count++)); + } for (HloInstruction* instruction : computation->instructions()) { if (IsDomainInstruction(instruction)) { // If this is a kDomain of the kind we are currently processing, check @@ -85,9 +97,46 @@ Status HloDomainMap::Populate(HloComputation* computation) { continue; } TF_ASSIGN_OR_RETURN(std::unique_ptr domain, - CreateDomain(instruction)); + CreateDomain(instruction, 
instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } + TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + return Status::OK(); +} + +Status HloDomainMap::PopulateDomainMetadataMap() { + auto hash = [](const DomainMetadata* m) { return m->Hash(); }; + auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { + return a->Matches(*b); + }; + absl::flat_hash_map + domain_metadata(1024, hash, equal); + + for (auto& domain : instruction_domains_) { + int64 domain_metadata_id = -1; + if (!domain->enter_domains.empty()) { + const HloInstruction* domain_instruction = *domain->enter_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->user_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else if (!domain->exit_domains.empty()) { + const HloInstruction* domain_instruction = *domain->exit_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->operand_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else { + domain_metadata_id = 0; + } + TF_RET_CHECK(domain_metadata_id >= 0); + for (HloInstruction* instruction : domain->instructions) { + domain_metadata_id_[instruction] = domain_metadata_id; + } + } return Status::OK(); } @@ -143,14 +192,17 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, } StatusOr> HloDomainMap::CreateDomain( - HloInstruction* instruction) const { + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const { auto domain = absl::make_unique(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); - domain->instructions = MakeNonDomainInstructions(domain->reach_set); + domain->instructions = + MakeNonDomainInstructions(domain->reach_set, instructions_order); return std::move(domain); } -bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { +bool HloDomainMap::IsDomainInstruction( + const HloInstruction* instruction) const { if 
(instruction->opcode() != HloOpcode::kDomain) { return false; } @@ -168,7 +220,8 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set) { + const absl::flat_hash_set& instruction_set, + const InstructionOrderMap& instructions_order) { std::vector instructions; instructions.reserve(instruction_set.size()); for (HloInstruction* instruction : instruction_set) { @@ -176,9 +229,10 @@ HloDomainMap::MakeNonDomainInstructions( instructions.push_back(instruction); } } + // sort instructions according to instructions_order std::sort(instructions.begin(), instructions.end(), - [](HloInstruction* a, HloInstruction* b) { - return a->unique_id() < b->unique_id(); + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 1ca71597253eecfb45ae8f384240033a57045277..bce7d1aa7cf1822ef1608674e7bf9483c628e4b5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -19,14 +19,14 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,18 +58,27 @@ class HloDomainMap { } // Checks whether two instructions are within the same domain. 
- bool InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const; + bool InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const; // Checks whether instruction is a kDomain instruction of the kind we are // currently processing. - bool IsDomainInstruction(HloInstruction* instruction) const; + bool IsDomainInstruction(const HloInstruction* instruction) const; // Retrieves the domain identifier of the instruction, or -1 in case // instruction is not found within any domain. - int64 GetDomainId(HloInstruction* instruction) const; + int64 GetDomainId(const HloInstruction* instruction) const; + + // Returns the unique id of the domain metadata for the domain the given + // instruction belongs to. The given instruction must not be a kDomain + // instruction since each domain instruction is associated with 2 domains. + int64 GetDomainMetadataId(const HloInstruction* instruction) const; private: + // Map used for representing instruction ordering, i.e. + // order_map[a] < order_map[b] means a must be ordered before b. + using InstructionOrderMap = absl::flat_hash_map; + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} // Check if the kDomain instruction is facing (via its operand link) another @@ -95,16 +104,23 @@ class HloDomainMap { // Creates a domain data structure using the ExpandDomain() API. StatusOr> CreateDomain( - HloInstruction* instruction) const; + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const; // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set); + const absl::flat_hash_set& instruction_set, + const InstructionOrderMap& instructions_order); + + // Populates domain_metadata_id_ that maps each HloInstruction to the unique + // ID of its associated domain metatadata. 
+ Status PopulateDomainMetadataMap(); string domain_kind_; std::vector> instruction_domains_; - tensorflow::gtl::FlatMap instruction_to_domain_; + absl::flat_hash_map instruction_to_domain_; + absl::flat_hash_map domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 575149c8b8455e0bf36840ba9e62ef2a5028e2f5..d3c83c15ae3be67a64f3dc4bcb0312ae9fbc33e4 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -42,9 +42,12 @@ class DomainMetadata { // operand/user pathways, without crossing a kDomain instruction of a given // kind. The reach_set can contain kDomain instructions of other kinds, if // two domains of different kind intersect each other. - tensorflow::gtl::FlatSet reach_set; + absl::flat_hash_set reach_set; - // The same instructions in reach_set, but purged from kDomain instructions. + // The same instructions in reach_set, but purged from kDomain instructions + // and ordered according to their computation graph post-order, i.e. + // if instructions[pos_a] depends on instructions[pos_b], then pos_a > + // pos_b. std::vector instructions; // If we consider a graph edge as an arrow oriented from the operand to the @@ -52,8 +55,8 @@ class DomainMetadata { // whose dataflow enters the reach set (domain), while the exit_domains // contains the set of kDomain instructions whose dataflow exit the reach // set. 
- tensorflow::gtl::FlatSet enter_domains; - tensorflow::gtl::FlatSet exit_domains; + absl::flat_hash_set enter_domains; + absl::flat_hash_set exit_domains; }; virtual ~DomainMetadata() = default; @@ -69,6 +72,9 @@ class DomainMetadata { // two matches. virtual bool Matches(const DomainMetadata& other) const = 0; + // Returns the hash value of the metadata. + virtual size_t Hash() const = 0; + // Returns a string representation of the metadata. virtual string ToString() const = 0; }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index 97bc8ef604092acc849b55b09af8a24bf775529e..0fc30fb86c337a8bba5957d504caa7deeac9b86c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -26,7 +26,7 @@ namespace xla { // Removes all the kDomain instructions of a given kind from the input module, // and calls the normalizer to propagate the properties on the possibly new born // instructions. -class HloDomainRemover : public HloPassInterface { +class HloDomainRemover : public HloModulePass { public: // Creates a new HloDomainRemover object tasked at removing all the kDomain // instructions of a given kind. 
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 79e78ee2d052cfc6c9553e88e7945644aedc37cd..43e74d2f6f07bd685ad8683401138a4f06cd2ad2 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -29,11 +29,6 @@ namespace xla { namespace { class HloDomainTest : public HloVerifiedTestBase { - public: - HloDomainTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -104,6 +99,8 @@ class OpNameMetadata : public DomainMetadata { static absl::string_view KindName() { return "opname"; } + size_t Hash() const override { return std::hash()(opname_); } + private: string opname_; }; @@ -350,7 +347,8 @@ ENTRY entry { token = token[] after-all() infeed = ((f32[4], f32[4]), token[]) infeed(token), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} - infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, + sharding={{maximal device=1}, {maximal device=0}} gte0 = f32[4] get-tuple-element(infeed.data), index=0 gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) @@ -384,11 +382,8 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed = FindInstruction(module, "infeed"); - ASSERT_NE(infeed, nullptr); - HloInstruction* infeed_data = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( @@ -496,6 +491,7 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { 
ASSERT_TRUE(ParseModule(hlo_string).status().ok()); } +// Tuple inputs are domain instructions. TEST_F(HloDomainTest, DomainTuple) { const char* const hlo_string = R"( HloModule Module @@ -503,7 +499,8 @@ HloModule Module ENTRY entry { p0 = f32[4] parameter(0), sharding={maximal device=0} cst = u32[] constant(0), sharding={maximal device=1} - tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}} + tpl = (u32[], f32[4]) tuple(cst, p0), + sharding={{maximal device=1}, {maximal device=0}} ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} } )"; @@ -588,5 +585,109 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { EXPECT_FALSE(HasDomainEdge(module, "d", "c")); } +// Emulate instructions inserted at top and bottom within nested tuple domain. +TEST_F(HloDomainTest, DomainTupleTopBottomInsert) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=1} + p1 = (f32[5], f32[6]) parameter(1), + sharding={{maximal device=1}, {maximal device=0}} + tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1), + sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}} + ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1, + sharding={{maximal device=1}, {maximal device=0}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(ShardingDomainCreator{}); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tuple.0 instruction, in order to test domain sharding + // application. 
+ auto tuple0 = FindInstruction(module, "tuple.0"); + tuple0->clear_sharding(); + + // Insert the following instructons above and below tuple.0, to emulate other + // passes effects: + // COPY.0 + // \ / + // TUPLE.0 + // / \ + // COPY.1 \ + // / \ + // GTE.0 GTE.1 + // | | + // | COPY.2 + // \ / + // \ / + // TUPLE.1 + // | + auto tuple0_users = tuple0->users(); + auto computation = tuple0->parent(); + HloInstruction* copy0 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy, + tuple0->mutable_operand(1))); + TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0)); + + HloInstruction* copy1 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0)); + HloInstruction* gte0 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0)); + HloInstruction* gte1 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1)); + HloInstruction* copy2 = computation->AddInstruction( + HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1)); + HloInstruction* tuple1 = + computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2})); + + for (HloInstruction* user : tuple0_users) { + TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_TRUE(tuple0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tuple0->sharding()); + + EXPECT_TRUE(copy0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + 
copy0->sharding()); + + // copy1 has partial information only from gte.0, so in the end it gets no + // sharding at all. During propagation it does propagate the information from + // gte.0 though, enabling Tuple.0 to be fully sharded. + EXPECT_FALSE(copy1->has_sharding()); + + EXPECT_TRUE(gte0->has_sharding()); + EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding()); + + EXPECT_TRUE(gte1->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + gte1->sharding()); + + EXPECT_TRUE(copy2->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy2->sharding()); + + EXPECT_TRUE(tuple1->has_sharding()); + EXPECT_EQ(tuple0->sharding(), tuple1->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 81d6d69a8c59da2fc77cb2bab808602cd964fdaf..bea5cba38d018029c9805e1593fadad54460447e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -29,7 +29,7 @@ namespace xla { // Verifies that the domain instructions are consistent, and the each domain is // surrounded by the same metadata. 
-class HloDomainVerifier : public HloPassInterface { +class HloDomainVerifier : public HloModulePass { public: HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index b9244b8e9e5f34e7ac5113c8eacb6f8243eea314..72006e17e7e7ec09b62e88d05b695ec9f4c49647 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -151,7 +151,11 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); - if (!HasOperandType(hlo, eliminate_type_)) { + bool nullary = hlo->operands().empty(); + bool wrong_element_type = hlo->shape().element_type() == eliminate_type_; + bool should_eliminate_type = (nullary && wrong_element_type) || + HasOperandType(hlo, eliminate_type_); + if (!should_eliminate_type) { // If this CHECK fires, then this was an instruction that does not take // the elimination type as an operand but it does return it. This pass // does not have a feature to change the output type in that case, so diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 44ded2c2faf7c38d1e2f2aae577ddc07089bbb6a..4d2a942925288ba4c3977ffcd25b55746a555a5e 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -25,7 +25,7 @@ namespace xla { // inserting Convert ops. This allows a backend to support an element type while // only actually implementing the Convert op for that element type. This is // generally not the fastest approach, but it works. 
-class HloElementTypeConverter : public HloPassInterface { +class HloElementTypeConverter : public HloModulePass { public: // eliminate_type is the type to eliminate as the input or output of ops, // using Convert ops to replace it with replace_with_type. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ca1c4dd0e9bc7286704ef31ee3dfdc63b6c154b8..c2998883851481b3cda5a3423baa3454018117b2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" @@ -53,12 +54,9 @@ namespace xla { namespace { -using tensorflow::gtl::ArraySlice; - template -StatusOr> Compare(const Shape& shape, HloOpcode opcode, - LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -96,19 +94,20 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } template <> -StatusOr> Compare( - const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function 
compare_op; switch (opcode) { case HloOpcode::kEq: @@ -126,11 +125,12 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -190,11 +190,16 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) return Unimplemented( "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); + typed_visitors_[TOKEN] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN."); + }); } template -StatusOr> HloEvaluator::Evaluate( - const HloModule& module, ArraySlice arg_literals) { +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -206,12 +211,23 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .CloneToUnique(); + .Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(module, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, ArraySlice arg_literals) { +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ 
-223,12 +239,22 @@ StatusOr> HloEvaluator::Evaluate( } TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(computation, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, ArraySlice arg_literals) { +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); evaluated_.clear(); @@ -246,18 +272,27 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = input_literal->CloneToUnique(); + evaluated_[operand] = input_literal->Clone(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal : arg_literals) { + arg_literal_ptrs.push_back(&literal); + } + return Evaluate(instruction, arg_literal_ptrs); } -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction) { +StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kParameter) { return tensorflow::errors::FailedPrecondition( "Cannot evaluate a parameter."); @@ -273,21 +308,22 @@ StatusOr> HloEvaluator::Evaluate( 
TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -std::unique_ptr HloEvaluator::TryEvaluate( - HloInstruction* instruction) { +bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) { + CHECK(result != nullptr); auto result_or = Evaluate(instruction); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); - return nullptr; + return false; } - return result_or.ConsumeValueOrDie(); + *result = result_or.ConsumeValueOrDie(); + return true; } -StatusOr> HloEvaluator::EvaluateWithSubstitutions( +StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions) { @@ -298,7 +334,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( owned_operands.push_back(operand->Clone()); } else { owned_operands.push_back( - HloInstruction::CreateConstant(it->second->CloneToUnique())); + HloInstruction::CreateConstant(it->second->Clone())); } } @@ -315,12 +351,12 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( +StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), @@ -330,10 +366,10 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( +StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const 
Literal& operand) { std::unique_ptr operand_instr = - HloInstruction::CreateConstant(operand.CloneToUnique()); + HloInstruction::CreateConstant(operand.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); @@ -342,13 +378,14 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr> HloEvaluator::EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, +StatusOr HloEvaluator::EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); TF_ASSIGN_OR_RETURN( Shape dot_shape, @@ -356,7 +393,7 @@ StatusOr> HloEvaluator::EvaluateDotOp( std::unique_ptr cloned_instruction = HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), - dim_numbers); + dim_numbers, precision_config); return Evaluate(cloned_instruction.get()); } @@ -369,7 +406,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = input_literal->CloneToUnique(); + evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -390,7 +427,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. 
const Shape& reference_shape = operands[0]->shape(); @@ -419,7 +456,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -435,7 +472,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { if (!ShapeUtil::ElementIsFloating(operand->shape())) { return InvalidArgument( "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type()).c_str()); + PrimitiveType_Name(operand->shape().element_type())); } switch (operand->shape().element_type()) { @@ -465,6 +502,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { return Status::OK(); } +Status HloEvaluator::HandleReal(HloInstruction* real) { + auto operand = real->operand(0); + switch (operand->shape().element_type()) { + case BF16: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](bfloat16 elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex64 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F16: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](Eigen::half elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](float elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], 
std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](double elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleImag(HloInstruction* imag) { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + return Status::OK(); +} + Status HloEvaluator::HandleCompare(HloInstruction* compare) { HloOpcode opcode = compare->opcode(); auto lhs = compare->operand(0); @@ -476,9 +568,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(compare->shape()), + ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -588,7 +680,7 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( // Return an ShapeUtil::IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. 
ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( - int64 output_rank, ArraySlice slice_sizes, + int64 output_rank, absl::Span slice_sizes, const GatherDimensionNumbers& dim_numbers) { std::vector index_base(output_rank, 0); std::vector index_count(output_rank, 1); @@ -660,12 +752,13 @@ class OutputBatchIndexToInputIndex { // index_vector_index_ and index_vector on every invocation, we reuse the // same storage for all invocations. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()(ArraySlice output_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return ArraySlice(input_index_); + return absl::Span(input_index_); } private: @@ -674,7 +767,7 @@ class OutputBatchIndexToInputIndex { // update the dim_numbers.index_vector_dim() dimension -- that's the dimension // we iterate over in FetchIndexVector. void PropagateOutputIndexGatherDimsToIndexVectorIndex( - ArraySlice output_index) { + absl::Span output_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = output_index.size(); i < e; i++) { if (!output_dim_is_batch_dims_[i]) { @@ -729,7 +822,7 @@ class OutputBatchIndexToInputIndex { // The index vector fetched from start_indices_. std::vector index_vector_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. std::vector input_index_; @@ -778,10 +871,11 @@ class OutputOffsetIndexToInputIndex { // gather input index on every invocation we reuse the same storage for the // result (input_index_), mutating it in place. // - // This returns an arrayslice into memory owned by the class. 
- StatusOr> operator()(ArraySlice output_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); - return ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding output dimension index, @@ -794,7 +888,7 @@ class OutputOffsetIndexToInputIndex { // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. void PropagateOutputIndexWindowDimsToInputIndex( - ArraySlice output_index) { + absl::Span output_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_output_index_[i] != -1) { input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; @@ -810,7 +904,7 @@ class OutputOffsetIndexToInputIndex { // PropagateOutputIndexWindowDimsToInputIndex. std::vector input_dim_value_to_output_index_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. std::vector input_index_; }; @@ -820,7 +914,7 @@ class OutputOffsetIndexToInputIndex { // there is one) to `reshaped_start_indices`. 
static StatusOr> ReshapedGatherIndices( int64 index_vector_dim, const Literal& start_indices, - std::unique_ptr* reshaped_start_indices) { + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -830,16 +924,16 @@ static StatusOr> ReshapedGatherIndices( new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_start_indices, start_indices.Reshape(new_shape)); - return std::cref(**reshaped_start_indices); + return std::cref(*reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { - std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + Literal result = Literal::CreateFromShape(gather->shape()); const Shape& shape = gather->shape(); const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_start_indices; + Literal reshaped_start_indices; TF_ASSIGN_OR_RETURN( const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), @@ -872,11 +966,11 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { const Shape& operand_shape = operand.shape(); auto gather_inner_loop_body = - [&](ArraySlice output_window_index, - ArraySlice input_gather_index, - ArraySlice output_gather_index) -> StatusOr { + [&](absl::Span output_window_index, + absl::Span input_gather_index, + absl::Span output_gather_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - ArraySlice input_window_index, + absl::Span input_window_index, output_offset_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; @@ -904,13 +998,13 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( - result->CopyElementFrom(operand, input_index, output_index)); + 
result.CopyElementFrom(operand, input_index, output_index)); return true; }; auto gather_outer_loop_body = - [&](ArraySlice output_gather_index) -> StatusOr { - TF_ASSIGN_OR_RETURN(ArraySlice input_gather_index, + [&](absl::Span output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( shape, offset_indices_iteration_space, @@ -936,8 +1030,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand.shape().dimensions(i)); + auto operand_dim_size = operand.shape().dimensions(i); + auto broadcast_dim_size = + broadcast->shape().dimensions(broadcast->dimensions(i)); + TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat( + "Operand dimension %d is broadcast to output dimension %d, but the " + "sizes of these two dims do not match (%d vs %d): %s", + i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size, + broadcast->ToString()); } TF_ASSIGN_OR_RETURN( @@ -967,18 +1067,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = absl::make_unique( - ShapeUtil::GetTupleElementShape(operand->shape(), index)); - return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, - /*dest_shape_index=*/{}, - /*src_shape_index=*/{index}); + evaluated_[get_tuple_element] = + Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status 
HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - - auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); - evaluated_[copy] = std::move(result); + evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone(); return Status::OK(); } @@ -994,7 +1092,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); @@ -1026,7 +1124,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator .Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); @@ -1046,7 +1144,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; - std::unique_ptr result; + Literal result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -1071,9 +1169,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. 
if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { - evaluated_[select] = on_true.CloneToUnique(); + evaluated_[select] = on_true.Clone(); } else { - evaluated_[select] = on_false.CloneToUnique(); + evaluated_[select] = on_false.Clone(); } return Status::OK(); } @@ -1087,9 +1185,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); if (pred.Get({})) { - evaluated_[tuple_select] = on_true.CloneToUnique(); + evaluated_[tuple_select] = on_true.Clone(); } else { - evaluated_[tuple_select] = on_false.CloneToUnique(); + evaluated_[tuple_select] = on_false.Clone(); } return Status::OK(); } @@ -1098,23 +1196,23 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); // Initialize the loop carried valued with the input to the While instruction. - auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); HloEvaluator loop_body_evaluator(max_loop_iterations_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", - while_hlo->name().c_str(), max_loop_iterations_); + return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", + while_hlo->name(), max_loop_iterations_); } - TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( - *cond_comp, {lcv.get()})); - keep_going = cond_val->GetFirstElement(); + TF_ASSIGN_OR_RETURN(auto cond_val, + cond_evaluator.Evaluate(*cond_comp, {&lcv})); + keep_going = cond_val.GetFirstElement(); if (keep_going) { TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {lcv.get()})); - VLOG(3) 
<< "Loop iteration result: " << body_val->ToString(); + *body_comp, {&lcv})); + VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); loop_body_evaluator.ResetVisitStates(); @@ -1129,99 +1227,106 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { // hoops to make this work. namespace { template -StatusOr> EvaluateSortInternal( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { auto rank = ShapeUtil::Rank(keys_literal.shape()); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " - << rank; - TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; - // We need to sort and array of keys and an array of values, where the + TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; + // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the // array using the keys, and then go back to pair-of-arrays. 
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - auto sort_r1 = [](const Literal& keys_literal, - const Literal& values_literal) { - const auto& keys_data = keys_literal.data(); - const auto& values_data = values_literal.data(); - - using kv_pair = std::pair; - std::vector key_value_vector; - CHECK_EQ(keys_data.size(), values_data.size()); - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); - } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - std::vector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - auto result_keys_literal = absl::make_unique(keys_literal.shape()); - result_keys_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_keys)); - auto result_values_literal = - absl::make_unique(values_literal.shape()); - result_values_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_values)); - return std::make_pair(std::move(result_keys_literal), - std::move(result_values_literal)); - }; - - std::unique_ptr result_tuple; - if (rank == 1) { - auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = LiteralUtil::MakeTuple( - {result_pair.first.get(), result_pair.second.get()}); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. 
- auto keys_result_literal = absl::make_unique(keys_literal.shape()); - auto values_result_literal = - absl::make_unique(values_literal.shape()); - int64 r1_length = keys_literal.shape().dimensions(1); - for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto keys_r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - TF_ASSIGN_OR_RETURN(auto values_r1_slice, - values_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); - TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first->Reshape({1, r1_length})); - TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( - *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( - *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); - } - result_tuple = LiteralUtil::MakeTuple( - {keys_result_literal.get(), values_result_literal.get()}); + if (rank == 0) { + // Nothing to sort. + return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); } - VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys_literal.shape(), zero_base, + AsInt64Slice(keys_literal.shape().dimensions()), increment, + [&](absl::Span indices) -> StatusOr { + // Extract a slice from the keys and values literals that correspond to + // exactly the row in dimension 'sort_dim'. 
+ std::vector limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto keys_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& keys_data = keys_to_sort.data(); + TF_ASSIGN_OR_RETURN(auto values_to_sort, + values_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& values_data = values_to_sort.data(); + using kv_pair = std::pair; + std::vector key_value_vector; + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys; + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. 
+ absl::InlinedVector result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + Literal sorted_keys(ShapeUtil::MakeShape( + keys_literal.shape().element_type(), {sort_dim_elements})); + sorted_keys.PopulateR1(absl::Span(result_keys)); + Literal sorted_values(ShapeUtil::MakeShape( + values_literal.shape().element_type(), {sort_dim_elements})); + sorted_values.PopulateR1(absl::Span(result_values)); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + std::vector start_indices(rank, 0); + TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, + sorted_keys.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys_reshaped, start_indices, indices, slice_dimensions)); + TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, + sorted_values.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + + Literal result_tuple; + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } template -StatusOr> EvaluateSortCurried( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { - switch (sort->operand(1)->shape().element_type()) { +StatusOr EvaluateSortCurried(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { + switch (values_literal.shape().element_type()) { + case PRED: + return EvaluateSortInternal(sort, keys_literal, + values_literal); case F32: return EvaluateSortInternal(sort, keys_literal, values_literal); @@ -1239,9 +1344,9 @@ StatusOr> EvaluateSortCurried( } } -StatusOr> EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { +StatusOr 
EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(0)->shape().element_type()) { case F32: return EvaluateSortCurried(sort, keys_literal, values_literal); @@ -1258,26 +1363,43 @@ StatusOr> EvaluateSort(HloInstruction* sort, } // namespace Status HloEvaluator::HandleSort(HloInstruction* sort) { - const int64 sort_dim = sort->dimensions(0); - const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); - if (sort_dim != rank - 1) { - return Unimplemented( - "Trying to support along dimension %lld, which is not the last " - "dimension", - sort_dim); - } - if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(1))); - if (result.ok()) { - evaluated_[sort] = std::move(result.ValueOrDie()); - return Status::OK(); - } else { - return result.status(); + // This is a really stupid work-around for the fact it's hard to support a + // multi-value sort directly, due to the fact we need to template the + // evaluation function on all of the value types. 
+ std::vector sort_results_backing; + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), + GetEvaluatedLiteralFor(sort->operand(i))); + if (!result.ok()) { + return result.status(); + } + sort_results_backing.push_back( + std::move(result.ValueOrDie().DecomposeTuple()[1])); + } + std::vector sort_results; + absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + [](const Literal& literal) { return &literal; }); + evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); + return Status::OK(); + } +} + +Status HloEvaluator::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsTuple(reduce->shape())) { + return DefaultAction(reduce); + } else { + auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); + for (const auto& tuple_shape : reduce->shape().tuple_shapes()) { + if (tuple_shape.element_type() != first_element_type) { + return Unimplemented( + "Reduce with several outputs that have mixed element types is " + "unsupported"); + } } + return reduce->Visit(typed_visitors_[first_element_type].get()); } } @@ -1289,32 +1411,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } // Explicit instantiation of templatized Evaluate* methods. 
// -template StatusOr> -HloEvaluator::Evaluate(const HloModule& module, - ArraySlice arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - const HloModule& module, ArraySlice> arg_literals); - -template StatusOr> -HloEvaluator::Evaluate(const HloComputation& computation, - ArraySlice arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( +template StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals); + +template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - ArraySlice> arg_literals); - -template StatusOr> -HloEvaluator::Evaluate(HloInstruction* instruction, - ArraySlice arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - HloInstruction* instruction, - ArraySlice> arg_literals); + absl::Span arg_literals); + +template StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 7588916de5068416410daf1a71a0bbad56f3ef0b..07f8d0aad4af0b07303b4e485b3630cc75bcb519 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,7 +18,9 @@ limitations under the License. #include +#include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,8 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -47,12 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -70,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -83,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. 
- // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + StatusOr Evaluate(HloInstruction* instruction, + absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. // Precondition: // 1. all operands of the input instruction are constants. // 2. the instruction is not a Parameter operation. - StatusOr> Evaluate(HloInstruction* instruction); + StatusOr Evaluate(HloInstruction* instruction); - // Same as Evaluate, except returning nullptr on error. - std::unique_ptr TryEvaluate(HloInstruction* instruction); + // Same as Evaluate, except returning false on error and accepts an output + // pointer. + bool TryEvaluate(HloInstruction* instruction, Literal* result); // Evaluates a single HLO instruction, substituting the given literals for // some of the instruction's operands. // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). 
- StatusOr> EvaluateWithSubstitutions( + StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions); - StatusOr> EvaluateElementwiseBinaryOp( - HloOpcode opcode, const Literal& lhs, const Literal& rhs); + StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr> EvaluateElementwiseUnaryOp( - HloOpcode opcode, const Literal& operand); + StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, - const Literal& rhs); + StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this @@ -135,7 +134,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. Status DefaultAction(HloInstruction* hlo) override { - return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); + return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get()); } Status Preprocess(HloInstruction* hlo) override; @@ -185,6 +184,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSort(HloInstruction* sort) override; + Status HandleReal(HloInstruction* real) override; + + Status HandleImag(HloInstruction* imag) override; + + Status HandleReduce(HloInstruction* reduce) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. 
@@ -196,7 +201,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); + return it->second; } // Tracks the HLO instruction and its evaluated literal result. @@ -204,12 +209,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // that are no longer a parent for any other subsequent instruction in // post-orderring. // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; + // Storing Literal in place require the container to have pointer stability so + // we cannot use flat_hash_map any more. + absl::node_hash_map evaluated_; private: template - static StatusOr> ElementWiseUnaryOpImpl( + static StatusOr ElementWiseUnaryOpImpl( HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { @@ -222,25 +228,20 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); + ShapeUtil::HumanString(shape), + ShapeUtil::HumanString(operand->shape())); } - auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); } // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. 
This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; + std::unique_ptr typed_visitors_[PrimitiveType_ARRAYSIZE]; // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index c3af15c6a88e42d0339fddcccd7dae7c6b62fb52..608a42bb60702aa075daca39535ca1672dcc5467 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,15 +52,11 @@ static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: - HloEvaluatorTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } - std::unique_ptr Evaluate( - tensorflow::gtl::ArraySlice arg_literals = {}) { + Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); @@ -70,41 +66,53 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, .ConsumeValueOrDie(); } + // Evaluate function that takes in a local module instead of using module_ + // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // removed, this should be the default Evaluate function. + Literal EvaluateWithModule( + HloModule* module, absl::Span arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
+ auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(module).ValueOrDie(); + } + return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } + std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr input, float aabs = 0) { + void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, + float aabs = 0) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - b.AddInstruction( - HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto element_type = expected->shape().element_type(); + auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } } - void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr lhs, - std::unique_ptr rhs) { + void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs, + Literal rhs) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( - HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + 
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } bool use_bfloat16_; @@ -120,7 +128,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - Shape shape = low->shape(); + Shape shape = low.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -129,11 +137,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -141,7 +149,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); - Shape shape = value->shape(); + Shape shape = value.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -150,11 +158,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -164,7 +172,7 @@ TEST_P(HloEvaluatorTest, 
DoesSelect) { auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - Shape shape = on_true->shape(); + Shape shape = on_true.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred))); auto c2 = @@ -175,11 +183,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -298,7 +306,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + std::vector args = {&lhs, &rhs, &rhs2}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -316,11 +324,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(args); + Literal result = Evaluate(args); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies Reshape operation is correctly evaluated. 
@@ -330,7 +338,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(literal))); @@ -340,14 +348,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - result->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { - std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); - }); + result.EachCell([&](absl::Span indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_NEAR(value, literal_clone.Get(rindexes), 0.031250); + }); } // Verifies Broadcast operation is correctly evaluated. @@ -359,12 +366,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, {1, 2})); + output_literal.shape(), literal_instruction, {1, 2})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -377,13 +384,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloInstruction::CreateConstant(std::move(input_literal))); // Broadcast dimension should be empty in the case of scalars. 
b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, + output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -401,11 +408,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -423,10 +430,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({100, 200}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -435,17 +442,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); - ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + 
b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -455,17 +462,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); - ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } PaddingConfig CreatePaddingConfig( @@ -498,12 +505,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -525,7 +532,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, 
r4_padding_on_dim0_dim1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -538,7 +545,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -569,7 +576,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique>(1, 5); @@ -580,7 +587,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -614,12 +621,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -649,10 +656,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); 
module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -664,7 +672,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -694,14 +702,15 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -737,10 +746,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -750,7 +760,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -788,17 +798,18 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { dnums.set_kernel_input_feature_dimension(1); 
dnums.add_kernel_spatial_dimensions(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -842,12 +853,13 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -860,7 +872,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -925,22 +937,23 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = 
ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); - Array4D expected_array_bf16({{{{2512, 2672}}}}); + Array4D expected_array_bf16({{{{2512, 2688}}}}); // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1002,22 +1015,23 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); - Array4D expected_array_bf16({{{{2512, 2672}}}}); + Array4D expected_array_bf16({{{{2512, 2688}}}}); // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? 
expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1061,12 +1075,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1080,7 +1095,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1124,12 +1139,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1144,7 +1160,7 @@ 
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, @@ -1195,12 +1211,13 @@ TEST_P(HloEvaluatorTest, ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1216,15 +1233,71 @@ TEST_P(HloEvaluatorTest, })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase { - public: - HloEvaluatorPreciseReduceTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; +TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { + HloComputation::Builder b(TestName()); + std::vector input_dims = {1, 2, 2, 4}; + std::vector filter_dims = {2, 2, 2, 8}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + std::iota(input_elems.begin(), input_elems.end(), -7); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + std::iota(filter_elems.begin(), filter_elems.end(), -31); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, + /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + Array4D expected_array(1, 1, 1, 8); + expected_array.FillWithYX( + Array2D({{668, 664, 660, 656, 668, 680, 692, 704}})); + auto expected = 
LiteralUtil::CreateR4FromArray4D(expected_array); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). @@ -1254,9 +1327,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { module().AddEntryComputation(b.Build()); HloEvaluator hlo_eval; - std::unique_ptr result = - hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR0Equal(kNumElements, *result); + Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, result); } // Reducing many numbers should be fast because it doesn't create @@ -1333,11 +1405,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({6, 18}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1385,10 +1457,62 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{6, 7}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[3,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // } + auto arg_array = absl::make_unique>(3, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto 
init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(2); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = LiteralUtil::CreateR2({{11}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1442,10 +1566,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1453,7 +1577,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. 
std::vector input_dims(6, 4); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = @@ -1503,12 +1627,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; - std::unique_ptr result_literal = + Literal result_literal = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); - EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1535,14 +1659,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*strides=*/{2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1569,14 +1693,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1605,14 +1729,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, 
DynamicSliceUpdate) { @@ -1642,14 +1766,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { shape, operand, update, start_indices)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1678,14 +1802,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1717,16 +1841,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); - auto expected = LiteralUtil::MakeTuple({ - result_inner_literal.get(), - result_inner_literal.get(), - }); + auto expected = + LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1757,7 +1879,7 @@ TEST_P(HloEvaluatorTest, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -1779,7 +1901,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + 
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1795,12 +1917,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; + Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, - {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, &param0_literal}, {square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1820,11 +1943,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. 
HloEvaluator evaluator; - auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); + auto result = + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1843,12 +1967,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1867,12 +1991,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1891,14 +2015,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - 
LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3( + LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), start_indices.get()}))); + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1917,15 +2040,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, @@ -1945,15 +2067,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1972,12 +2093,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); 
- EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1996,13 +2116,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2021,11 +2140,10 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2045,12 +2163,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr start_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - 
*Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2075,15 +2193,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2108,15 +2224,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2142,15 +2257,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 
6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2176,15 +2289,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { @@ -2210,17 +2321,15 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); EXPECT_TRUE(LiteralTestUtil::Near( - 
*LiteralUtil::CreateR2( + LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), - ErrorSpec{0.1, 0.01})); + Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2246,15 +2355,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2280,15 +2387,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + Evaluate({&operand, &scatter_indices, &updates}))); } 
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2313,21 +2419,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + Literal expected = LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, @@ -2353,21 +2456,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + Literal expected = LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2392,16 +2492,14 @@ ENTRY main { } )"; 
ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2426,17 +2524,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2461,13 +2556,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); 
EXPECT_TRUE(LiteralTestUtil::Equal( - *operand, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + operand, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2494,16 +2587,121 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - std::unique_ptr expected = - LiteralUtil::CreateR1({10, 61, 32}); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + Literal expected = LiteralUtil::CreateR1({10, 61, 32}); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, Evaluate({&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_NegativeIndices + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the negative indices. 
+ Literal scatter_indices = LiteralUtil::CreateR1({-1, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { + const string hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the OOB indices. 
+ Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}}); + Literal updates = LiteralUtil::CreateR3({{{-10, 10}, {-40, 40}}}); + // Given the update window size of 2,2 and the index of 0,2, the update window + // will be OOB. So, nothing should be updated. 
+ Literal expected = operand.Clone(); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2540,11 +2738,29 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr arg = LiteralUtil::CreateR1( + Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); - std::unique_ptr expected = - LiteralUtil::CreateR0(bfloat16(44.0f)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); + Literal expected = LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); +} + +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 2da2cc2d71ed94315cfc15a737155b65f9e8f7ad..84fbbd3e0c3ddb704b8db601897f3b199dc99626 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,12 +16,16 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include + #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -39,7 +43,9 @@ template using is_complex64_t = std::is_same; // It's UB to use std::sort with std::less, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. +// "safe" less functions which are actually strict weak orders. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -47,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return a < b; } -template ::value || - std::is_same::value>::type* = nullptr> +template ::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; + bool lhs_is_negative = std::signbit(a); + bool rhs_is_negative = std::signbit(b); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(a); + bool rhs_nan = std::isnan(b); + // Exactly one number is nan? 
+ if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; } + return a < b; } -template ::value>::type* = nullptr> +template ::value || + std::is_same::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } + return SafeLess(static_cast(a), static_cast(b)); } // Templated DfsHloVisitor for use by HloEvaluator. @@ -76,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. +// - HandleImag and HandleReal: where the resulting literal type is always float +// and the operand is always complex, or real in the case of HandleReal. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. // @@ -95,7 +110,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> double GetAsDouble(const Literal& literal, - tensorflow::gtl::ArraySlice input_index) { + absl::Span input_index) { return static_cast(literal.Get(input_index)); } @@ -107,7 +122,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> double GetAsDouble(const Literal& literal, - tensorflow::gtl::ArraySlice input_index) { + absl::Span input_index) { LOG(FATAL) << "Trying to get complex literal as double: " << literal.ToString(); } @@ -143,7 +158,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } // TODO(b/35950897): many of the stl functions used in the 
handlers are not @@ -244,32 +259,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -327,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor(floor); } - Status HandleImag(HloInstruction* imag) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], - ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { - return std::imag(elem_operand); - })); - return Status::OK(); - } - Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -682,14 +678,6 @@ class HloEvaluatorTypedVisitor : public 
DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleReal(HloInstruction* real) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], - ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { - return std::real(elem_operand); - })); - return Status::OK(); - } - template ::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { @@ -976,10 +964,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1019,9 +1007,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1044,10 +1033,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); + int64 feature_group_count = conv->feature_group_count(); + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - 
rhs_literal_data]( - tensorflow::gtl::ArraySlice out_index) { + rhs_literal_data, + feature_group_count](const absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1058,7 +1049,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_batch_dim = dnums.output_batch_dimension(); const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 input_z_size = + ShapeUtil::GetDimension(lhs_shape, input_z_dim); + // The size of an input feature group. + const int64 input_feature_group_size = input_z_size / feature_group_count; + + const int64 output_z_size = + ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); + // The output feature dimension is a concatenation of convolution results + // from the different groups. + const int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current output index + // belongs. + const int64 feature_group_index = + out_index[output_z_dim] / output_feature_group_size; ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1066,75 +1072,79 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 lhs_linear_index = 0; + // Find corresponding spatial dimension index for input (lhs). + int64 lhs_linear_spatial_index = 0; + int64 rhs_linear_spatial_index = 0; + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. 
+ const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + lhs_linear_spatial_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + rhs_linear_spatial_index += + (window_dim.window_reversal() + ? 
((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { + const int64 iz = + feature_group_index * input_feature_group_size + rhs_iz; + + int64 lhs_linear_index = lhs_linear_spatial_index; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - int64 rhs_linear_index = 0; + int64 rhs_linear_index = rhs_linear_spatial_index; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; - - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. 
- int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } + rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; result_val += static_cast(lhs_literal_data[lhs_linear_index]) * static_cast(rhs_literal_data[rhs_linear_index]); } cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + } while (IndexUtil::BumpIndices(window_shape, + absl::MakeSpan(rhs_spatial_index))); return static_cast(result_val); }; - auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1196,20 +1206,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { if (i != lhs_contracting_dimension && - !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { result_index_locations.push_back({&lhs_index[i], nullptr}); } } for (int64 i = 0; i < rhs_rank; i++) { if (i != rhs_contracting_dimension && - !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { result_index_locations.push_back({&rhs_index[i], nullptr}); } } - auto result = 
absl::make_unique(dot->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { + Literal result(dot->shape()); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1256,24 +1266,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = absl::make_unique(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate( + [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. 
const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto func = [&](absl::Span input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1289,8 +1297,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set(target_index, - evaluated_operand.Get(input_index)); + result.Set(target_index, + evaluated_operand.Get(input_index)); return true; }; @@ -1417,16 +1425,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> MapImpl(HloInstruction* map) { + StatusOr MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = absl::make_unique(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - std::vector> arg_literals; + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + std::vector arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1441,16 +1449,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. 
embedded_evaluator.ResetVisitStates(); - return computed_result->Get({}); + return computed_result.Get({}); })); return std::move(result); } @@ -1518,48 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - auto rank = ShapeUtil::Rank(keys->shape()); - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - - auto sort_r1 = [this](const Literal& keys_literal) { - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data(); - - std::vector result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess(a, b); - }); - auto result_literal = absl::make_unique(keys_literal.shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); - return result_literal; - }; - - if (rank == 1) { - parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. 
- auto result_literal = absl::make_unique(keys_literal.shape()); - int64 r1_length = keys->shape().dimensions(1); - for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); - } - parent_->evaluated_[sort] = std::move(result_literal); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys->shape().dimensions(sort_dim); + int64 rank = ShapeUtil::Rank(keys->shape()); + if (rank == 0) { + // Nothing to sort. + parent_->evaluated_[sort] = keys_literal.Clone(); + return Status::OK(); } + Literal result_literal(keys_literal.shape()); + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), + increment, [&](absl::Span indices) -> StatusOr { + // Extract a slice from the literal that corresponds to exactly the + // row in dimension 'sort_dim'. 
+ std::vector limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto row_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& row_data = row_to_sort.data(); + + std::vector result_data(row_data.begin(), row_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); + Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), + {sort_dim_elements})); + sorted_row.PopulateR1(absl::Span(result_data)); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, + sorted_row.Reshape(slice_dimensions)); + std::vector start_indices(rank, 0); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + sorted_row_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + parent_->evaluated_[sort] = std::move(result_literal); return Status::OK(); } @@ -1575,20 +1588,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleSort(sort); } - Status HandleReduce(HloInstruction* reduce) override { - // TODO(b/112040122): Support variadic reduce. 
- if (!ShapeUtil::IsArray(reduce->shape())) { - return Unimplemented("Variadic reduce is not supported in the Evaluator"); - } - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + Status HandleReduce(HloInstruction* hlo) override { + HloReduceInstruction* reduce = Cast(hlo); + int64 num_args = reduce->inputs().size(); + bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); + + absl::InlinedVector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); + } TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - {&arg->shape(), &init_value->shape()}, + operand_shapes, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1596,14 +1609,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); + absl::InlinedVector arg_literals(num_args); + absl::InlinedVector init_literals(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]); + VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString(); + init_literals[i] = + 
&parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]); + VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape())); + } - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + // All args and results have the same dimensions, so pick an arbitrary one. + const Shape& arg_shape = arg_literals[0]->shape(); + const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) + ? reduce->shape().tuple_shapes(0) + : reduce->shape(); + const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); std::vector arg_dim_steps(arg_dimensions.size()); std::vector arg_dim_counts(arg_dimensions.size()); for (const int64 dim : dimensions) { @@ -1621,63 +1643,106 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique(reduce->shape()); + absl::InlinedVector results(num_args); + for (int64 i = 0; i < num_args; ++i) { + results[i] = Literal(result_shape); + } + Status eval_status; - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = init_scalar; - if (!eval_status.ok()) { - return result_val; - } + // For each resulting dimension, calculate and assign computed values. + // This is really wasteful when num_args > 1, since we re-run the + // reduction num_args time. The alternative is to teach Populate() about + // tuples, which we should probably do. 
+ absl::InlinedVector init_scalars(num_args); + for (int i = 0; i < num_args; ++i) { + init_scalars[i] = init_literals[i]->Get({}); + } - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } + for (int64 input = 0; input < num_args; ++input) { + TF_RETURN_IF_ERROR(results[input].Populate( + [&](absl::Span multi_index) { + if (!eval_status.ok()) { + return init_scalars[input]; + } + absl::InlinedVector result_values(init_scalars.begin(), + init_scalars.end()); + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) && + IsScalarAdd(function)) { + CHECK_EQ(num_args, 1); + double computed_result = 0; + auto func = [&](absl::Span input_index) { + computed_result += + GetAsDouble(*arg_literals[0], input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base, + arg_dim_counts, arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = + [&](absl::Span input_index) -> StatusOr { + absl::InlinedVector arg_values(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_values[i] = arg_literals[i]->Get(input_index); + } - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](tensorflow::gtl::ArraySlice input_index) { - computed_result += GetAsDouble(arg_literal, input_index); + // Evaluate computation with specified literal operands. 
+ absl::InlinedVector embedded_operands; + for (ReturnT value : result_values) { + embedded_operands.push_back( + LiteralUtil::CreateR0(value)); + } + for (ReturnT value : arg_values) { + embedded_operands.push_back( + LiteralUtil::CreateR0(value)); + } + absl::InlinedVector embedded_operands_ptrs( + embedded_operands.size()); + std::transform(embedded_operands.begin(), embedded_operands.end(), + embedded_operands_ptrs.begin(), + [](Literal& literal) { return &literal; }); + + TF_ASSIGN_OR_RETURN(Literal computed_result, + embedded_evaluator.Evaluate( + *function, embedded_operands_ptrs)); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + if (!has_tuple_output) { + result_values[0] = computed_result.Get({}); + } else { + for (int64 i = 0; i < num_args; ++i) { + result_values[i] = computed_result.Get( + /*multi_index=*/{}, /*shape_index=*/{i}); + } + } return true; }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = [&](tensorflow::gtl::ArraySlice input_index) - -> StatusOr { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = LiteralUtil::CreateR0(curr_val); - auto result_val_literal = - LiteralUtil::CreateR0(result_val); - - TF_ASSIGN_OR_RETURN(std::unique_ptr computed_result, - embedded_evaluator.Evaluate( - *function, {result_val_literal.get(), - curr_val_literal.get()})); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. 
- eval_status = ShapeUtil::ForEachIndexWithStatus( - arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + eval_status = ShapeUtil::ForEachIndexWithStatus( + arg_shape, base, arg_dim_counts, arg_dim_steps, func); + return result_values[input]; + })); + } + if (!has_tuple_output) { + parent_->evaluated_[reduce] = std::move(results[0]); + } else { + Literal tuple_result(reduce->shape()); + for (int64 i = 0; i < num_args; ++i) { + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); + } + parent_->evaluated_[reduce] = std::move(tuple_result); + } return eval_status; } @@ -1705,13 +1770,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = absl::make_unique(select_and_scatter->shape()); + Literal result(select_and_scatter->shape()); // Initialize result array with the init value. 
- TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); + TF_RETURN_IF_ERROR(result.Populate( + [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1760,15 +1823,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); + bool selected = !computed_result.Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1782,22 +1844,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr computed_result = + auto scattered = result.Get(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); + result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again // on the same computation. 
embedded_evaluator.ResetVisitStates(); } }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + } while ( + IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index))); parent_->evaluated_[select_and_scatter] = std::move(result); return Status::OK(); @@ -1841,10 +1904,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1860,18 +1923,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(curr_val); const auto result_val_literal = LiteralUtil::CreateR0(result_val); - std::unique_ptr computed_result = + Literal computed_result = embedded_evaluator .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get({}); + result_val = computed_result.Get({}); }); return result_val; @@ -1886,7 +1948,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // literal (if there is one) to `reshaped_indices`. 
StatusOr> ReshapedScatterIndices( int64 index_vector_dim, const Literal& indices, - std::unique_ptr* reshaped_indices) { + Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { return std::cref(indices); } @@ -1895,7 +1957,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { indices.shape().dimensions().end()); new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); - return std::cref(**reshaped_indices); + return std::cref(*reshaped_indices); } // Returns an ShapeUtil::IndexIterationSpace that iterates over the update @@ -1989,13 +2051,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // index_vector_index_ and index_vector on every invocation, we reuse the // same storage for all invocations. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()( - tensorflow::gtl::ArraySlice update_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span update_index) { PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return tensorflow::gtl::ArraySlice(input_index_); + return absl::Span(input_index_); } private: @@ -2004,7 +2066,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // update the dim_numbers.index_vector_dim() dimension -- that's the // dimension we iterate over in FetchIndexVector. void PropagateUpdateIndexScatterDimsToIndexVectorIndex( - tensorflow::gtl::ArraySlice update_index) { + absl::Span update_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = update_index.size(); i < e; i++) { if (!update_dim_is_scatter_dims_[i]) { @@ -2059,7 +2121,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // The index vector fetched from scatter_indices_. std::vector index_vector_; - // The result computed by this functor. 
operator() returns an ArraySlice + // The result computed by this functor. operator() returns a Span // into this vector. std::vector input_index_; @@ -2112,11 +2174,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // scatter input index on every invocation we reuse the same storage for the // result (input_index_), mutating it in place. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()( - tensorflow::gtl::ArraySlice update_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span update_index) { PropagateUpdateIndexWindowDimsToInputIndex(update_index); - return tensorflow::gtl::ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding update dimension index, @@ -2129,7 +2191,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Propagates window dimensions from the update index to input_index_ by // mutating input_index_ in place. void PropagateUpdateIndexWindowDimsToInputIndex( - tensorflow::gtl::ArraySlice update_index) { + absl::Span update_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_update_index_[i] != -1) { input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; @@ -2145,7 +2207,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // PropagateUpdateIndexWindowDimsToInputIndex. std::vector input_dim_value_to_update_index_; - // The result computed by this functor. operator() returns an ArraySlice + // The result computed by this functor. operator() returns a Span // into this vector. 
std::vector input_index_; }; @@ -2155,7 +2217,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scatter->scatter_dimension_numbers(); const Literal& operand = parent_->GetEvaluatedLiteralFor(scatter->operand(0)); - std::unique_ptr reshaped_scatter_indices; + Literal reshaped_scatter_indices; TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, ReshapedScatterIndices(dim_numbers.index_vector_dim(), parent_->GetEvaluatedLiteralFor( @@ -2185,15 +2247,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize the result with the operand. This makes it easier to handle // the updates even when the indices are repeated. - std::unique_ptr result = operand.CloneToUnique(); + Literal result = operand.Clone(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = - [&](tensorflow::gtl::ArraySlice update_window_index, - tensorflow::gtl::ArraySlice input_scatter_index, - tensorflow::gtl::ArraySlice update_scatter_index) - -> StatusOr { + [&](absl::Span update_window_index, + absl::Span input_scatter_index, + absl::Span update_scatter_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - tensorflow::gtl::ArraySlice input_window_index, + absl::Span input_window_index, update_window_index_to_input_index(update_window_index)); for (int i = 0, e = update_index.size(); i < e; i++) { update_index[i] = update_scatter_index[i] + update_window_index[i]; @@ -2209,47 +2270,43 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // be 1. int64 update_dim_size = update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); - // Clamp the scatter index so that the scatter region fits in the - // operand. 
input_scatter_index_clamped[i] = - // clamp(input_scatter_index[i], 0, - // operand_shape.dimensions(i) - - // update_dim_size); - input_scatter_index_clamped[i] = - std::min(operand_shape.dimensions(i) - update_dim_size, - std::max(0LL, input_scatter_index[i])); + // If any part of the update region is out-of-bounds, then do not + // perform any update on the input. + if ((input_scatter_index[i] < 0) || + (input_scatter_index[i] > + operand_shape.dimensions(i) - update_dim_size)) { + return true; + } } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; - DCHECK_GE(input_index[i], 0); - DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + input_index[i] = input_scatter_index[i] + input_window_index[i]; } auto result_value_literal = - LiteralUtil::CreateR0(result->Get(input_index)); + LiteralUtil::CreateR0(result.Get(input_index)); auto update_value_literal = LiteralUtil::CreateR0(updates.Get(update_index)); - std::unique_ptr updated_result = + Literal updated_result = embedded_evaluator .Evaluate( *scatter->to_apply(), - {result_value_literal.get(), update_value_literal.get()}) + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. 
embedded_evaluator.ResetVisitStates(); - result->Set(input_index, updated_result->Get({})); + result.Set(input_index, updated_result.Get({})); return true; }; auto scatter_outer_loop_body = - [&](tensorflow::gtl::ArraySlice update_scatter_index) - -> StatusOr { + [&](absl::Span update_scatter_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - tensorflow::gtl::ArraySlice input_scatter_index, + absl::Span input_scatter_index, update_scatter_index_to_input_index(update_scatter_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( updates_shape, window_indices_iteration_space, - [&](tensorflow::gtl::ArraySlice update_window_index) { + [&](absl::Span update_window_index) { return scatter_inner_loop_body( update_window_index, input_scatter_index, update_scatter_index); })); @@ -2277,7 +2334,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 rank = ShapeUtil::Rank(operand->shape()); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { operand_index[i] = @@ -2286,9 +2343,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); + Literal result(shape); + TF_RETURN_IF_ERROR(result.Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2493,11 +2549,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value || std::is_same::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = absl::make_unique(iota->shape()); - auto data = result->data(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = 
Cast(instruction); + std::vector data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result.Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template & window_count_index, + const absl::Span& window_count_index, const std::function&)>& f) { const int64 rank = ShapeUtil::Rank(base_shape); DimensionVector window_index(rank); @@ -2547,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. 
+ base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; @@ -2557,13 +2632,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!out_of_bound) { f(base_index); } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } while ( + IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index))); } template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2576,9 +2652,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + Literal result(result_shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2592,12 +2668,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector 
start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2605,15 +2681,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); - auto func = [&](tensorflow::gtl::ArraySlice update_index) { + auto func = [&](absl::Span update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - result->Set(result_index, - update_literal.Get(update_index)); + result.Set(result_index, + update_literal.Get(update_index)); return true; }; @@ -2626,7 +2702,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr> ElementWiseUnaryOp( + StatusOr ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -2639,7 +2715,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr> ElementWiseBinaryOp( + StatusOr ElementWiseBinaryOp( HloInstruction* instruction, const std::function& binary_op) { @@ -2654,18 +2730,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = 
absl::make_unique(shape); + Literal result(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2674,7 +2749,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> ElementwiseTernaryOp( + StatusOr ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -2690,20 +2765,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape()), + ShapeUtil::HumanString(ehs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = absl::make_unique(shape); + Literal result(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index de3d7a167752f0de790585e50874dd6d2904bd37..ce4cad42355ec5881f2ae14f4dd52a0588d51cf7 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ 
b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -90,8 +90,9 @@ std::unique_ptr CreateHloProfilePrinterData( HloInstructionInfo* instruction_info = computation_info->add_instruction_infos(); instruction_info->set_long_name(hlo->ToString()); - instruction_info->set_short_name( - hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_short_name(hlo->ToString( + HloPrintOptions().set_compact_operands(true).set_print_operand_names( + false))); instruction_info->set_category(hlo->ToCategory()); instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); instruction_info->set_transcendental_count( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 59c628e945a4e27c1b0f447d165babec4898b81c..13a74fd8a115c5dc9a9518b226dfee4445cc7180 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/types/optional.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" @@ -57,32 +57,12 @@ using absl::nullopt; using absl::optional; using absl::StrAppend; using absl::StrCat; +using absl::StrFormat; using absl::StrJoin; using tensorflow::Env; using tensorflow::WriteStringToFile; using tensorflow::io::JoinPath; -// Helpers for Printf and Appendf. 
-template -struct PrintfConvert { - const T& operator()(const T& t) const { return t; } -}; -template <> -struct PrintfConvert { - const char* operator()(const string& s) const { return s.c_str(); } -}; - -// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() -// on strings. -template -string Printf(const char* fmt, const Ts&... ts) { - return tensorflow::strings::Printf(fmt, PrintfConvert()(ts)...); -} -template -void Appendf(string* s, const char* fmt, const Ts&... ts) { - tensorflow::strings::Appendf(s, fmt, PrintfConvert()(ts)...); -} - // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? enum NodeFilterResult { @@ -140,12 +120,23 @@ class NodeFilter { std::function filter_; }; +// We arbitrarily set this as the boundary between "large" and "small" +// instructions. +bool IsSmall(const HloInstruction* instr) { + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { + return true; + } + return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; +} + // Node color schemes, used by NodeColorAttributes. enum ColorScheme { kBlue, kBrown, kDarkBlue, kDarkGreen, + kDarkOrange, kDarkRed, kGray, kGreen, @@ -178,6 +169,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) { return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; case kDarkGreen: return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkOrange: + // This is more of a "medium" orange, made to look close to kOrange; + // there's probably room for a darker weight if desired. 
+ return NodeColors{"filled", "#ffb74d", "#c88719", "black"}; case kDarkRed: return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; case kGray: @@ -210,10 +205,9 @@ NodeColors NodeColorsForScheme(ColorScheme color) { string NodeColorAttributes(ColorScheme color) { NodeColors node_colors = NodeColorsForScheme(color); - return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", - node_colors.style, node_colors.font_color, node_colors.stroke_color, - node_colors.fill_color); + return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a @@ -326,7 +320,7 @@ class HloDotDumper { const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), - label_(std::string(label)), + label_(label), debug_options_(debug_options), show_backend_config_(show_backend_config), profile_(profile), @@ -448,7 +442,7 @@ string HloDotDumper::Dump() { } string HloDotDumper::Header() { - const char* fmt = R"(digraph G { + constexpr char fmt[] = R"(digraph G { rankdir = TB; compound = true; label = <%s>; @@ -475,14 +469,13 @@ stylesheet=< string graph_label = StrCat(label_, "
Computation ", computation_->name()); if (computation_->IsFusionComputation()) { - StrAppend(&graph_label, - StrCat(" (in fusion instruction ", - computation_->FusionInstruction()->name(), ")")); + StrAppend(&graph_label, " (in fusion instruction ", + computation_->FusionInstruction()->name(), ")"); } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); - Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles)); + absl::StrAppendFormat(&graph_label, "
total cycles = %d (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); } // Create CSS rules that say, when you hover over the given node or cluster, @@ -509,14 +502,14 @@ stylesheet=< // One could imagine other ways of writing this CSS rule that involve // less duplication, but this way seems to be relatively performant. edge_css_rules.push_back( - Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n" - " #%s%d:hover ~ #edge%lld path { " - "stroke: %s; stroke-width: .2em; }\n" - " #%s%d:hover ~ #edge%lld polygon { " - "fill: %s; stroke: %s; stroke-width: .2em; }\n", - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, color)); + StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n" + " #%s%d:hover ~ #edge%d path { " + "stroke: %s; stroke-width: .2em; }\n" + " #%s%d:hover ~ #edge%d polygon { " + "fill: %s; stroke: %s; stroke-width: .2em; }\n", + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, color)); }; // The "to_node" value may be a NULL, indicating that this points to the @@ -559,7 +552,7 @@ stylesheet=< } } - return Printf(fmt, graph_label, StrJoin(edge_css_rules, "\n")); + return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); } string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } @@ -600,9 +593,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() << " as " << next_edge_id_; edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = + constexpr char edge_fmt[] = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( + edges_.push_back(StrFormat( edge_fmt, InstructionId(from), InstructionId(parent_instr), SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } @@ -619,9 +612,10 @@ string 
HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, string subcomp_label, style; if (parent_instr->opcode() == HloOpcode::kFusion) { - subcomp_label = Printf("Fused expression for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(parent_instr->ToCategory())); + subcomp_label = + StrFormat("Fused expression for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); string extra_info = GetInstructionNodeExtraInfo(parent_instr); if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); @@ -647,18 +641,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; } style = - Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", - fillcolor, strokecolor); + StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", + fillcolor, strokecolor); } else { - subcomp_label = Printf("Subcomputation for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(subcomp->name())); + subcomp_label = StrFormat("Subcomputation for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); style = "style=rounded; color=black;"; } string comp_body = DumpComputation(subcomp); - const char* computation_fmt = R"(subgraph %s { + constexpr char computation_fmt[] = R"(subgraph %s { %s label = <%s>; labelloc = t; @@ -667,7 +661,7 @@ tooltip = " "; } // %s )"; - return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -718,11 +712,11 @@ string HloDotDumper::DumpRootTag() { VLOG(2) << "Adding edge from " << from->name() << " to root tag as " << next_edge_id_; edge_ids_.insert({{from, to}, next_edge_id_++}); - edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); + edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" - "\n", - to_id, node_body, node_shape, NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + "\n", + to_id, node_body, node_shape, NodeColorAttributes(color)); } static const HloConstantInstruction* TryGetFusionParameterConstant( @@ -817,10 +811,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" - "\n", - InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" + "\n", + InstructionId(instr), node_body, node_shape, node_metadata, + NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -833,7 +827,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. 
if (ShapeUtil::IsZeroElementArray(shape)) { - return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. @@ -848,8 +842,8 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // collected from profiling tools. Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return Printf("%s (%s)", constant->literal().ToString(), - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s (%s)", constant->literal().ToString(), + ShapeUtil::HumanString(constant->shape())); } // Otherwise, print e.g. "%constant.42 (s32[100])". @@ -859,8 +853,8 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( } else { constant_name = StrCat("constant ", constant->name()); } - return Printf("%s %s", constant_name, - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; std::vector lines; @@ -881,7 +875,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( TryGetFusionParameterConstant(operand)) { operand_str = stringify_constant(constant); } else { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + operand_str = StrFormat("Parameter %d", operand->parameter_number()); } } else { operand_str = operand->name(); @@ -890,9 +884,9 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( if (operand_str) { if (instr->operand_count() > 1) { - lines.push_back(Printf("operand %lld = %s", i, *operand_str)); + lines.push_back(StrFormat("operand %d = %s", i, *operand_str)); } else { - lines.push_back(Printf("operand = %s", *operand_str)); + lines.push_back(StrFormat("operand = %s", *operand_str)); } } } @@ -913,7 +907,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { sharding_colors_.emplace(instr->sharding(), color); return 
color; } - const auto kParameterColor = kOrange; + + // Choose different weights of orange for small vs large parameters. This + // distinction is often important, especially in fusion nodes. + auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange; // Special case: If this instruction has a parameter merged into it, paint it // the same color as a parameter. Unless the merged-in parameter is a @@ -925,7 +922,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { ShouldMergeIntoUsers(operand) && TryGetFusionParameterConstant(operand) == nullptr; })) { - return kParameterColor; + return parameter_color; } // Pick different colors or shapes for instructions which are particularly @@ -1035,7 +1032,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kParameterColor; + return parameter_color; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1049,6 +1046,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: @@ -1079,13 +1077,13 @@ string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // If we have a parameter, put the param number in the name. if (instr->opcode() == HloOpcode::kParameter) { - return Printf("Parameter %lld", instr->parameter_number()); + return StrFormat("Parameter %d", instr->parameter_number()); } // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. 
if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) { - return Printf("%s", HtmlLikeStringSanitize(instr->name())); + return StrFormat("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = StrCat(HloOpcodeString(instr->opcode()), @@ -1093,8 +1091,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { ? "" : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("%s
%s", HtmlLikeStringSanitize(extended_opcode), - HtmlLikeStringSanitize(instr->name())); + return StrFormat("%s
%s", HtmlLikeStringSanitize(extended_opcode), + HtmlLikeStringSanitize(instr->name())); } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { @@ -1103,16 +1101,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); } if (!instr->metadata().op_type().empty()) { - lines.push_back(Printf( + lines.push_back(StrFormat( "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type()))); } if (!instr->metadata().source_file().empty() && instr->metadata().source_line() != 0) { - lines.push_back(Printf("op_type: %s", instr->metadata().source_file(), - instr->metadata().source_line())); + lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(), + instr->metadata().source_line())); } - return StrJoin(lines, "
"); + return StrJoin(lines, "\n"); } string HloDotDumper::GetInstructionNodeBackendConfig( @@ -1164,7 +1162,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { lines.push_back(instr_shape); } if (debug_options_.xla_hlo_graph_addresses()) { - lines.push_back(Printf("[%p]", instr)); + lines.push_back(StrFormat("[%p]", instr)); } if (profile_ != nullptr) { double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); @@ -1172,27 +1170,13 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { lines.push_back( - Printf("%% of cycles executed=%.2f", - 100 * hlo_cycles_executed / total_cycles_executed)); + StrFormat("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } } return StrJoin(lines, "
"); } -// Gets the total number of array elements in the given shape. For tuples, this -// is the sum of all the sizes of all of the array elements recursively in the -// tuple. -static int64 TotalElementsInShape(const Shape& shape) { - int64 elems = 0; - ShapeUtil::ForEachSubshape( - shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { - elems += ShapeUtil::ElementsIn(subshape); - } - }); - return elems; -} - void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1208,20 +1192,19 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { string edge_label; if (instr->operand_count() > 1 && !control_edge) { - edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); + edge_label = + StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num); } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } // We print "small" arrays using a hollow arrowhead and "large" arrays using - // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" - // means. - bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - - const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; - edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), from->name(), - to->name(), edge_label)); + // a filled arrowhead. + constexpr char kEdgeFmt[] = + R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; + edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), + (IsSmall(from) ? "empty" : "normal"), + from->name(), to->name(), edge_label)); }; // Add edges from instr's operands to instr. 
Parameters within fusion @@ -1262,11 +1245,11 @@ string HloDotDumper::GetInstructionTrivialComputationStr( continue; } if (instr->called_computations().size() == 1) { - lines.push_back(Printf("Subcomputation: %s", - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation: %s", + HtmlLikeStringSanitize(*computation_type))); } else { - lines.push_back(Printf("Subcomputation %lld: %s", i, - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation %d: %s", i, + HtmlLikeStringSanitize(*computation_type))); } } return StrJoin(lines, "
"); diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc new file mode 100644 index 0000000000000000000000000000000000000000..8128fad07ca0b9c3883ed93c6e1c8e977e990cb4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) { + TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) + << absl::StrCat("Tring to set up alias at ", output_index.ToString(), + " which is an invalid index for shape ", + ShapeUtil::HumanString(alias_.shape())); + // Output can't be aliased with multiple parameters. 
+ TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( + "Trying to set up output alias for param %lld at %s but failed: output " + "index %s is already aliased with param %lld at %s", + param_number, param_index.ToString(), output_index.ToString(), + alias_.element(output_index)->first, + alias_.element(output_index)->second.ToString()); + (*alias_.mutable_element(output_index)) = + std::make_pair(param_number, param_index); + return Status::OK(); +} + +HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { + HloInputOutputAliasProto result; + alias_.ForEachElement( + [&](const ShapeIndex& index, + const absl::optional>& data) { + if (data) { + HloInputOutputAliasProto::AliasEntryProto entry; + for (int64 i : index) { + entry.add_output_shape_index(i); + } + entry.set_parameter_number(data->first); + for (int64 i : data->second) { + entry.add_parameter_shape_index(i); + } + result.add_entries()->Swap(&entry); + } + }); + return result; +} + +StatusOr HloInputOutputAliasConfig::CreateFromProto( + const Shape& output_shape, const HloInputOutputAliasProto& proto) { + HloInputOutputAliasConfig result(output_shape); + for (const HloInputOutputAliasProto::AliasEntryProto& entry : + proto.entries()) { + ShapeIndex output_index(entry.output_shape_index().begin(), + entry.output_shape_index().end()); + + int64 param_number = entry.parameter_number(); + ShapeIndex param_index(entry.parameter_shape_index().begin(), + entry.parameter_shape_index().end()); + TF_RETURN_IF_ERROR( + result.SetUpAlias(output_index, param_number, param_index)); + } + + return result; +} + +string HloInputOutputAliasConfig::ToString() const { + std::vector pieces; + pieces.push_back("HloInputOutputAliasConfig"); + + ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + pieces.push_back(absl::StrFormat( + " OutputIndex %s is aliased with parameter %lld at %s:", + output_index.ToString(), param_number, 
param_index.ToString())); + }); + + return absl::StrJoin(pieces, "\n"); +} + +bool HloInputOutputAliasConfig::ParameterHasAlias( + int64 param_number, const ShapeIndex& param_index) const { + bool output = false; + alias_.ForEachElement( + [&](const xla::ShapeIndex&, + absl::optional> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = true; + } + }); + return output; +} + +absl::optional HloInputOutputAliasConfig::GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const { + absl::optional output; + alias_.ForEachElement( + [&](const xla::ShapeIndex& output_index, + absl::optional> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = output_index; + } + }); + return output; +} + +absl::optional> +HloInputOutputAliasConfig::GetAliasedParameter( + const ShapeIndex& output_index) const { + CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); + return alias_.element(output_index); +} + +void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { + alias_.ForEachElement( + [&](const ShapeIndex& output_index, + absl::optional> aliased) { + if (aliased) { + fn(output_index, aliased->first, aliased->second); + } + }); +} + +Status HloInputOutputAliasConfig::ForEachAliasWithStatus( + AliasFnWithStatus fn) const { + return alias_.ForEachElementWithStatus( + [&](const ShapeIndex& output_index, + absl::optional> aliased) { + if (aliased) { + TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + } + return Status::OK(); + }); +} + +Status HloInputOutputAliasConfig::Verify(const HloModule& module) const { + std::vector> param_has_seen; + const HloComputation* entry = module.entry_computation(); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + HloInstruction* param = entry->parameter_instruction(i); + param_has_seen.emplace_back(param->shape()); + } + return ForEachAliasWithStatus([&](const ShapeIndex& output_index, + int64 
param_number, + const ShapeIndex& param_index) -> Status { + const HloInstruction* root = entry->root_instruction(); + + const Shape& param_shape = + entry->parameter_instruction(param_number)->shape(); + const Shape& output_shape = root->shape(); + TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); + + // Check each param_number and param_index pair only show up once. No + // input can be aliased with output buffers. + TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); + + *(param_has_seen[param_number].mutable_element(param_index)) = true; + + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config) { + out << config.ToString(); + return out; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h new file mode 100644 index 0000000000000000000000000000000000000000..0fae75842ba28da5dcb59e5952cd60c1d1c5ea68 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; + +// This class specifies the alias map from output index to parameter number and +// parameter index in the entry computation. +class HloInputOutputAliasConfig { + public: + HloInputOutputAliasConfig() = default; + + explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + + virtual ~HloInputOutputAliasConfig() = default; + + // Sets up alias config from `output_index` to `param_index` at + // `param_number`. + Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index); + + // Returns true if the given parameter is aliased with one of the output + // buffers. + bool ParameterHasAlias(int64 param_number, + const ShapeIndex& param_index) const; + + // (De)Serializes an HloInputOutoutAliasConfig to/from an + // HloInputOutoutAliasProto. + HloInputOutputAliasProto ToProto() const; + + static StatusOr CreateFromProto( + const Shape& output_shape, const HloInputOutputAliasProto& proto); + + // Returns the output index that the given parameter and parameter index is + // aliased with. A nullopt is returned if there is no output that is aliased + // with the parameter number and index. + absl::optional GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const; + + // Returns the number of parameter and index of the parameter buffer that the + // given output buffer index is aliased with. A nullopt is returned if there + // is no parameter is aliased with the specific output. 
+ absl::optional> GetAliasedParameter( + const ShapeIndex& output_index) const; + + using AliasFn = + std::function; + + // Iterates through each aliased output and input. + void ForEachAlias(AliasFn fn) const; + + using AliasFnWithStatus = + std::function; + + // Verifies that the given config is valid for the given module. + // Specifically, the config's input and output should be in-bound and size of + // the aliased buffers should match. + Status Verify(const HloModule& module) const; + + Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; + + string ToString() const; + + private: + // A ShapeTree which indicates the list of buffers that's expected to be + // aliased. The key on this shape tree represents the output index. The value + // is a pair of parameter number and index into the buffer. If the value is + // nullopt, it means there is no parameter aliasing for this output. + ShapeTree>> alias_; +}; + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b61ff04e6d7eeaa5876775fa18a85af82164b3d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class HloInputOutputAliasConfigTest : public HloTestBase { + protected: + void expect_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_TRUE(aliased_output); + EXPECT_EQ(aliased_output.value(), output_index); + + absl::optional> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_TRUE(aliased_param); + EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + } + + void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_FALSE(aliased_output && aliased_output == output_index); + + absl::optional> aliased_param = + 
config.GetAliasedParameter(output_index); + + EXPECT_FALSE(aliased_param && aliased_param->first == param_number && + aliased_param->second == param_index); + } +}; + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, config); + + expect_aliased(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, 
/*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.Verify(*module)); +} + +TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2bb9de686ffbcf276f9e92e1894e1fed8fbea129..b6df63c983d7297cb26b9cf528f41fa54a343cd7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,6 +22,8 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" @@ -37,14 +39,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" @@ -59,8 +60,8 @@ using absl::StrJoin; /* static */ StatusOr> HloInstruction::CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map) { + const absl::flat_hash_map& instruction_map, + const absl::flat_hash_map& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -80,6 +81,20 @@ StatusOr> HloInstruction::CreateFromProto( const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; + + TF_RET_CHECK(std::all_of( + proto.operand_ids().begin(), proto.operand_ids().end(), + [&instruction_map](int64 id) { return instruction_map.contains(id); })) + << proto.name() << " instruction contains invalid operand id(s)"; + + TF_RET_CHECK(std::all_of( + 
proto.called_computation_ids().begin(), + proto.called_computation_ids().end(), + [&computation_map](int64 id) { return computation_map.contains(id); })) + << proto.name() << " instruction references invalid computation id(s)"; + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: @@ -113,7 +128,7 @@ StatusOr> HloInstruction::CreateFromProto( std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), - tensorflow::gtl::ArraySlice(fft_length)); + absl::Span(fft_length)); break; } case HloOpcode::kSend: @@ -158,29 +173,38 @@ StatusOr> HloInstruction::CreateFromProto( CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Reduce instruction should have 2 operands but sees " + TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) + << "Reduce instruction should have an even number of operands but " + "sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Reduce instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateReduce(proto.shape(), operands(0), operands(1), - std::vector(proto.dimensions().begin(), - proto.dimensions().end()), - computations(0)); + { + const auto reduce_operands = all_operands(); + auto inputs = absl::MakeSpan(reduce_operands) + .subspan(0, reduce_operands.size() / 2); + auto init_values = + absl::MakeSpan(reduce_operands) + .subspan(reduce_operands.size() / 2, reduce_operands.size()); + instruction = + CreateReduce(proto.shape(), inputs, init_values, + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + } break; case HloOpcode::kSort: { - TF_RET_CHECK(proto.operand_ids_size() == 1 || - proto.operand_ids_size() == 
2) - << "Sort instruction should have 1 or 2 operands but has " + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "Sort instruction should have at least 1 operand but has " << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; - HloInstruction* keys = operands(0); - HloInstruction* values = - proto.operand_ids_size() == 2 ? operands(1) : nullptr; - instruction = - CreateSort(proto.shape(), proto.dimensions(0), keys, values); + auto sort_operands = all_operands(); + HloInstruction* keys = sort_operands[0]; + instruction = CreateSort( + proto.shape(), proto.dimensions(0), keys, + absl::Span(sort_operands).subspan(1)); break; } case HloOpcode::kTranspose: @@ -240,7 +264,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); - instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } case HloOpcode::kFusion: { @@ -256,7 +280,8 @@ StatusOr> HloInstruction::CreateFromProto( << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + auto* fused_computation = + tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), @@ -279,6 +304,9 @@ StatusOr> HloInstruction::CreateFromProto( proto.tuple_index()); break; case HloOpcode::kReducePrecision: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "ReducePrecision instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), 
proto.mantissa_bits()); @@ -286,12 +314,18 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Infeed instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Outfeed instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), operands(1), proto.outfeed_config()); break; @@ -320,17 +354,35 @@ StatusOr> HloInstruction::CreateFromProto( proto.replica_groups().end())); break; } - case HloOpcode::kConvolution: + case HloOpcode::kCollectivePermute: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "CollectivePermute instruction should have 1 operand but sees " + << proto.operand_ids_size(); + std::vector> source_target_pairs( + proto.source_target_pairs_size()); + for (int i = 0; i < source_target_pairs.size(); i++) { + source_target_pairs[i].first = proto.source_target_pairs(i).source(); + source_target_pairs[i].second = proto.source_target_pairs(i).target(); + } + instruction = CreateCollectivePermute(proto.shape(), operands(0), + source_target_pairs); + break; + } + case HloOpcode::kConvolution: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), 
PrecisionConfig::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), proto.window(), - proto.convolution_dimension_numbers(), - std::max(static_cast(proto.feature_group_count()), 1LL)); + proto.shape(), operands(0), operands(1), + std::max(proto.feature_group_count(), 1), proto.window(), + proto.convolution_dimension_numbers(), precision_config); break; + } case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) << "ReduceWindow instruction should have 2 operands but sees " @@ -353,8 +405,22 @@ StatusOr> HloInstruction::CreateFromProto( operands(1), operands(2), computations(1)); break; case HloOpcode::kCustomCall: - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target()); + if (proto.constrain_layout()) { + // A proto RepeatedPtrField cannot be converted to a Span (it is a + // vector of pointers essentially) so create a vector of shapes to pass + // in. + std::vector operand_shapes; + for (const Shape& shape : proto.operand_shapes_with_layout()) { + operand_shapes.push_back(shape); + } + instruction = CreateCustomCall( + proto.shape(), all_operands(), proto.custom_call_target(), + operand_shapes, proto.custom_call_opaque()); + } else { + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target(), + proto.custom_call_opaque()); + } if (proto.has_window()) { static_cast(instruction.get()) ->set_window(proto.window()); @@ -364,6 +430,9 @@ StatusOr> HloInstruction::CreateFromProto( ->set_convolution_dimension_numbers( proto.convolution_dimension_numbers()); } + static_cast(instruction.get()) + ->set_feature_group_count( + std::max(static_cast(proto.feature_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -417,41 +486,77 @@ StatusOr> HloInstruction::CreateFromProto( computations(0), *scatter_dimension_numbers); break; } + case HloOpcode::kIota: + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Iota 
instruction should have 1 dimension but sees " + << proto.dimensions_size(); + instruction = CreateIota(proto.shape(), proto.dimensions(0)); + break; + case HloOpcode::kDot: { + TF_RET_CHECK(proto.has_dot_dimension_numbers()) + << "Dot instruction should have dot_dimension_numbers."; + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Dot instruction should have 2 operands but sees " + << proto.operand_ids_size(); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = absl::make_unique( + proto.shape(), operands(0), operands(1), + proto.dot_dimension_numbers(), precision_config); + break; + } + case HloOpcode::kDomain: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Domain instruction should have 1 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_domain_entry_sharding()) + << "Domain instruction must domain_entry_sharding"; + TF_RET_CHECK(proto.has_domain_exit_sharding()) + << "Domain instruction must domain_exit_sharding"; + TF_ASSIGN_OR_RETURN( + HloSharding entry_hlo_sharding, + HloSharding::FromProto(proto.domain_entry_sharding())); + TF_ASSIGN_OR_RETURN(HloSharding exit_hlo_sharding, + HloSharding::FromProto(proto.domain_exit_sharding())); + instruction = absl::make_unique( + proto.shape(), operands(0), + absl::make_unique( + std::make_shared(entry_hlo_sharding)), + absl::make_unique( + std::make_shared(exit_hlo_sharding))); + break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; instruction->AppendOperand(instruction_map.at(operand_id)); } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " 
<< predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); - } if (instruction->opcode() != HloOpcode::kFusion) { for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; instruction->called_computations_.push_back( computation_map.at(computation_id)); } } + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); break; } } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + TF_RET_CHECK(!proto.name().empty()); instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - instruction->precision_config_ = proto.precision_config(); - - if (proto.has_dot_dimension_numbers()) { - instruction->dot_dimension_numbers_ = - absl::make_unique(proto.dot_dimension_numbers()); - } + instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -474,13 +579,13 @@ StatusOr> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr HloInstruction::CreateConstant( - std::unique_ptr literal) { + Literal literal) { return absl::make_unique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateIota( - const Shape& shape) { - return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + const Shape& shape, int64 iota_dimension) { + return absl::make_unique(shape, iota_dimension); } /* static */ std::unique_ptr @@ -492,13 +597,13 @@ HloInstruction::CreateGetTupleElement(const Shape& 
shape, /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) { + absl::Span parameters) { return absl::make_unique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. CHECK(!ShapeUtil::IsOpaque(shape)); @@ -522,7 +627,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -554,7 +658,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kGe: case HloOpcode::kGt: @@ -600,58 +703,40 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK_EQ(HloOpcode::kTuple, opcode); return CreateNary(shape, opcode, operands); } /* static */ std::unique_ptr HloInstruction::CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation) { return absl::make_unique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& 
precision_config) { return absl::make_unique( - shape, lhs, rhs, window, dimension_numbers, feature_group_count); + shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { + absl::Span fft_length) { return absl::make_unique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique(dimension_numbers); - return instruction; -} - -/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique(); - instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); - instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); - return instruction; + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { + return absl::make_unique( + shape, lhs, rhs, dimension_numbers, precision_config); } /* static */ std::unique_ptr @@ -665,7 +750,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, 
const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) { @@ -675,12 +760,20 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateAllToAll( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups) { return absl::make_unique(shape, operands, replica_groups); } +/* static */ std::unique_ptr +HloInstruction::CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs) { + return absl::make_unique( + shape, operand, source_target_pairs); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -729,12 +822,12 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateAfterAll( - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK(!operands.empty()); auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); @@ -780,16 +873,15 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides) { return absl::make_unique(shape, operand, start_indices, limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) 
{ + absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, slice_sizes); } @@ -808,7 +900,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) { return absl::make_unique(shape, operands, dimension); @@ -833,7 +925,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { auto instruction = absl::WrapUnique(new HloReduceInstruction( shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); @@ -841,9 +933,9 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { std::vector all_args; all_args.reserve(operands.size() * 2); @@ -901,7 +993,7 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return absl::make_unique(shape, operand, broadcast_dimensions); } @@ -979,13 +1071,13 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return absl::make_unique(shape, operand, 
dimensions); } /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) { + absl::Span values) { return absl::make_unique(shape, dimension, keys, values); } @@ -997,7 +1089,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) { return absl::make_unique(shape, fusion_kind, operands, fusion_computation); @@ -1020,7 +1112,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); - derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1055,7 +1146,7 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation) { std::unique_ptr instruction = absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); @@ -1067,14 +1158,23 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - absl::string_view custom_call_target) { - return absl::make_unique(shape, operands, - custom_call_target); + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque) { + return absl::make_unique( + shape, operands, custom_call_target, opaque); +} + +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque) { + return absl::make_unique( + shape, operands, custom_call_target, opaque, 
operand_shapes_with_layout); } /* static */ std::unique_ptr HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (auto element : elements) { element_shapes.push_back(element->shape()); @@ -1086,7 +1186,7 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateGather( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, gather_dim_numbers, slice_sizes); } @@ -1105,17 +1205,13 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); - instruction->operand_side_metadata_ = std::move(operand_side_metadata); - instruction->user_side_metadata_ = std::move(user_side_metadata); - instruction->AppendOperand(operand); - return instruction; + return absl::make_unique( + shape, operand, std::move(operand_side_metadata), + std::move(user_side_metadata)); } std::unique_ptr HloInstruction::CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; @@ -1154,6 +1250,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1166,6 +1263,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kGather: case HloOpcode::kScatter: case HloOpcode::kIota: 
+ case HloOpcode::kDot: + case HloOpcode::kDomain: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1238,11 +1337,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kDot: - CHECK_EQ(new_operands.size(), 2); - clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); @@ -1267,12 +1361,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kDomain: - CHECK_EQ(new_operands.size(), 1); - clone = - CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), - user_side_metadata_->Clone()); - break; case HloOpcode::kAfterAll: if (new_operands.empty()) { clone = CreateToken(); @@ -1403,7 +1491,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; - tensorflow::gtl::FlatSet seen; + absl::flat_hash_set seen; for (HloInstruction* operand : operands()) { if (seen.insert(operand).second) { unique.push_back(operand); @@ -1465,7 +1553,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { } void HloInstruction::RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices) { + absl::Span ascending_indices) { if (ascending_indices.empty()) { return; } @@ -1568,11 +1656,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAfterAll: return false; - // Check dot dimension numbers. - case HloOpcode::kDot: - return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - other.dot_dimension_numbers()); - // Remaining instructions with special values. 
case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1588,10 +1671,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } - case HloOpcode::kDomain: - return operand_side_metadata().Matches(other.operand_side_metadata()) && - user_side_metadata().Matches(other.user_side_metadata()); - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. case HloOpcode::kBatchNormTraining: @@ -1622,6 +1701,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -1630,6 +1710,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kScatter: + case HloOpcode::kDot: + case HloOpcode::kDomain: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1960,7 +2042,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - tensorflow::gtl::ArraySlice slice(operands_); + absl::Span slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; if (options.compact_operands() && slice.size() > kMaxOperandsToShowIfCompact) { @@ -1983,7 +2065,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( options.is_in_nested_computation()) { str.push_back(PrintName( canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { + } else if (options.print_operand_names()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, StrJoin(str, " ")); @@ -1999,15 +2081,6 @@ std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector extra = ExtraAttributesToStringImpl(options); - if 
(dot_dimension_numbers_ != nullptr) { - extra.push_back(DotDimensionNumbersToString()); - } - - string precision_config_string = PrecisionConfigToString(); - if (!precision_config_string.empty()) { - extra.push_back(precision_config_string); - } - if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2032,13 +2105,12 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=", - StrJoin(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, - PrintName(computation->name(), options)); - }))); + extra.push_back(StrCat( + "calls=", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, PrintName(computation->name(), options)); + }))); } } else if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kFullBodies) { @@ -2085,7 +2157,7 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", [&](string* out, HloInstruction* pre) { @@ -2094,11 +2166,6 @@ std::vector HloInstruction::ExtraAttributesToString( }), "}")); } - if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { - extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", user_side_metadata_->ToString(), - ", exit=", operand_side_metadata_->ToString(), "}")); - } return extra; } @@ -2130,17 +2197,12 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; 
proto.set_backend_config(backend_config_); - *proto.mutable_precision_config() = precision_config_; if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - if (dot_dimension_numbers_ != nullptr) { - *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; - } - if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } @@ -2169,7 +2231,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } -bool HloInstruction::IsFusable() const { +bool HloInstruction::IsFusible() const { // Instructions which are traced should not be fused. if (tracing()) { return false; @@ -2275,6 +2337,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCrossReplicaSum(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kCollectivePermute: + return visitor->HandleCollectivePermute(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2381,7 +2445,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return InternalError( "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " "please file a bug for XLA.", - HloOpcodeString(opcode_).c_str()); + HloOpcodeString(opcode_)); } // Explicit instantiations. @@ -2419,7 +2483,7 @@ template static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { - visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds()); + visitor->ReserveVisitStates(root->GetModule()->instruction_count()); // dfs_stack holds pairs of unique_id(), HloInstruction*>. 
// @@ -2464,7 +2528,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } @@ -2473,7 +2537,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } } @@ -2613,7 +2677,6 @@ Status HloInstruction::AcceptOrdered( } const Shape& HloInstruction::shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; } @@ -2656,14 +2719,14 @@ class HloInstruction::FusionReusesParamElements { // the value of this parameter, which would save stack space but not allow us // to finish early if we find a reuse. static UseKind Compute(int64 i, const HloInstruction& hlo) { - tensorflow::gtl::FlatMap memoization_cache; + absl::flat_hash_map memoization_cache; return ComputeInternal(i, hlo, &memoization_cache); } private: static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, - tensorflow::gtl::FlatMap* cache) { + absl::flat_hash_map* cache) { if (auto hlo_param = DynCast(&hlo)) { if (hlo_param->parameter_number() == i) { return UseKind::kUse; @@ -2721,10 +2784,13 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { case HloOpcode::kTranspose: return UseKind::kUsePermutingElements; case HloOpcode::kPad: - case HloOpcode::kReduce: // Pad reuses the padding value but not the padded array elements. - // Reduce reuses the init value but not the operand array elements. return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; + case HloOpcode::kReduce: + // Reduce reuses the init values but not the operand array elements. 
+ return i >= Cast(this)->input_count() + ? UseKind::kReuse + : UseKind::kUsePermutingElements; case HloOpcode::kFusion: // Uses the memoizing, recursive computation defined above. return FusionReusesParamElements::Compute(i, *fused_expression_root()); @@ -2789,7 +2855,7 @@ StatusOr StringToFusionKind( if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } - return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); + return InvalidArgument("Unknown fusion kind: %s", kind_name); } string PaddingConfigToString(const PaddingConfig& padding) { @@ -2829,8 +2895,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); } -string PrecisionToString(const PrecisionConfigProto::Precision& precision) { - return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); +string PrecisionToString(const PrecisionConfig::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2862,31 +2928,6 @@ string ConvolutionDimensionNumbersToString( StrJoin(output_dims, "")); } -string HloInstruction::DotDimensionNumbersToString() const { - std::vector result; - if (dot_dimension_numbers_ == nullptr) { - return ""; - } - const DotDimensionNumbers& dnums = *dot_dimension_numbers_; - if (!dnums.lhs_batch_dimensions().empty()) { - result.push_back(StrCat("lhs_batch_dims={", - StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("lhs_contracting_dims={", - StrJoin(dnums.lhs_contracting_dimensions(), ","), - "}")); - - if (!dnums.rhs_batch_dimensions().empty()) { - result.push_back(StrCat("rhs_batch_dims={", - StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("rhs_contracting_dims={", - StrJoin(dnums.rhs_contracting_dimensions(), ","), - "}")); - - return StrJoin(result, ", "); -} - StatusOr 
StringToRandomDistribution(const string& name) { static std::unordered_map* map = [] { static auto* map = new std::unordered_map; @@ -2905,31 +2946,13 @@ StatusOr StringToRandomDistribution(const string& name) { return found->second; } -string HloInstruction::PrecisionConfigToString() const { - if (precision_config_.operand_precision().empty()) { - return ""; - } - return StrCat( - "operand_precision={", - StrJoin(precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfigProto::Precision_IsValid(precision)) - << precision; - StrAppend(out, PrecisionToString( - static_cast( - precision))); - }), - "}"); -} - -StatusOr StringToPrecision( - const string& name) { - static std::unordered_map* map = [] { +StatusOr StringToPrecision(const string& name) { + static std::unordered_map* map = [] { static auto* map = - new std::unordered_map; - for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { - if (PrecisionConfigProto::Precision_IsValid(i)) { - auto value = static_cast(i); + new std::unordered_map; + for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) { + if (PrecisionConfig::Precision_IsValid(i)) { + auto value = static_cast(i); (*map)[PrecisionToString(value)] = value; } } @@ -2946,6 +2969,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } +bool HloPtrComparator::operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. 
+ return false; + } + if (lhs == nullptr) { + return true; + } + auto lhs_module = lhs->GetModule(); + auto rhs_module = rhs->GetModule(); + CHECK((lhs_module == nullptr && rhs_module == nullptr) || + (lhs_module != nullptr && rhs_module != nullptr)); + if (lhs_module != nullptr && + lhs_module->unique_id() != rhs_module->unique_id()) { + return lhs_module->unique_id() < rhs_module->unique_id(); + } + return lhs->unique_id() < rhs->unique_id(); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -2982,6 +3025,16 @@ Status HloInstruction::set_backend_config( return ret; } +const PrecisionConfig& HloInstruction::precision_config() const { + if (auto* convolution = DynCast(this)) { + return convolution->precision_config(); + } + if (auto* dot = DynCast(this)) { + return dot->precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3053,10 +3106,6 @@ const std::vector& HloInstruction::slice_strides() const { return Cast(this)->slice_strides(); } -bool HloInstruction::IsInPlaceSlice() const { - return Cast(this)->IsInPlaceSlice(); -} - const Literal& HloInstruction::literal() const { return Cast(this)->literal(); } @@ -3189,13 +3238,18 @@ const std::vector& HloInstruction::replica_groups() const { return Cast(this)->replica_groups(); } +const std::vector>& +HloInstruction::source_target_pairs() const { + return Cast(this)->source_target_pairs(); +} + string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); + return Cast(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( - barrier); + return Cast(this)->set_cross_replica_sum_barrier( + barrier); } absl::optional HloInstruction::all_reduce_id() const { @@ -3225,7 +3279,15 @@ void 
HloInstruction::set_convolution_dimension_numbers( } int64 HloInstruction::feature_group_count() const { - return Cast(this)->feature_group_count(); + if (auto convolution = DynCast(this)) { + return convolution->feature_group_count(); + } + return Cast(this)->feature_group_count(); +} + +void HloInstruction::set_feature_group_count(int64 feature_group_count) { + Cast(this)->set_feature_group_count( + feature_group_count); } HloComputation* HloInstruction::select() const { @@ -3264,7 +3326,7 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { return Cast(this)->gather_dimension_numbers(); } -tensorflow::gtl::ArraySlice HloInstruction::gather_slice_sizes() const { +absl::Span HloInstruction::gather_slice_sizes() const { return Cast(this)->gather_slice_sizes(); } @@ -3273,4 +3335,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() return Cast(this)->scatter_dimension_numbers(); } +const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const { + return Cast(this)->dot_dimension_numbers(); +} + +const DomainMetadata& HloInstruction::operand_side_metadata() const { + return Cast(this)->operand_side_metadata(); +} + +const DomainMetadata& HloInstruction::user_side_metadata() const { + return Cast(this)->user_side_metadata(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 948e33a0a3520593d681223189ef852587e5934b..15a4da8dbe0053aad314989a6718ebd61532ab8b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -28,14 +28,15 @@ limitations under the License. 
#include #include #include -#include -#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -49,8 +50,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -80,8 +79,10 @@ class HloPrintOptions { print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), + print_operand_names_(true), print_program_shape_(true), print_percent_(true), + print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), is_in_nested_computation_(false) {} @@ -94,7 +95,8 @@ class HloPrintOptions { .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) - .set_print_percent(false); + .set_print_percent(false) + .set_print_control_dependencies(false); } // Options to produce the canonical string representing an isomorphic @@ -105,9 +107,11 @@ class HloPrintOptions { .set_print_metadata(false) .set_print_backend_config(false) .set_compact_operands(true) + .set_print_operand_names(false) .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) + .set_print_control_dependencies(false) .set_canonicalize_instruction_names(true); } @@ -141,6 +145,12 @@ class HloPrintOptions { return *this; } + // If true, the operand names will be printed. 
+ HloPrintOptions& set_print_operand_names(bool value) { + print_operand_names_ = value; + return *this; + } + // If true, program shape of hlo computations will be printed. HloPrintOptions& set_print_program_shape(bool value) { print_program_shape_ = value; @@ -153,8 +163,14 @@ class HloPrintOptions { return *this; } - // If true, only a part of operands will be printed out, and their names will - // be omitted (note that in this case the text will not be parsable). + // If true, control dependencies will be printed. + HloPrintOptions& set_print_control_dependencies(bool value) { + print_control_dependencies_ = value; + return *this; + } + + // If true, only a part of operands will be printed out (note that in this + // case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { compact_operands_ = value; return *this; @@ -188,8 +204,12 @@ class HloPrintOptions { bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } + bool print_operand_names() const { return print_operand_names_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool print_control_dependencies() const { + return print_control_dependencies_; + } bool canonicalize_instruction_names() const { return canonicalize_instruction_names_; } @@ -203,8 +223,10 @@ class HloPrintOptions { bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; + bool print_operand_names_; bool print_program_shape_; bool print_percent_; + bool print_control_dependencies_; bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; @@ -234,7 +256,7 @@ class CanonicalNameMap { private: int64 index; - tensorflow::gtl::FlatMap canonical_name_map; + absl::flat_hash_map canonical_name_map; }; // HLO instructions are the atomic unit of the 
high-level compiler's IR. @@ -337,8 +359,8 @@ class HloInstruction { // calls. static StatusOr> CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map); + const absl::flat_hash_map& instruction_map, + const absl::flat_hash_map& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -346,11 +368,11 @@ class HloInstruction { const string& name); // Creates a literal constant instruction. - static std::unique_ptr CreateConstant( - std::unique_ptr literal); + static std::unique_ptr CreateConstant(Literal literal); // Creates an Iota instruction. - static std::unique_ptr CreateIota(const Shape& shape); + static std::unique_ptr CreateIota(const Shape& shape, + int64 iota_dimension); // Creates a get tuple element instruction. static std::unique_ptr CreateGetTupleElement( @@ -364,7 +386,7 @@ class HloInstruction { // random numbers from a given distribution. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + absl::Span parameters); // Creates a unary instruction (one operand). // Precondition: opcode must be a legitimate unary operation. @@ -391,39 +413,34 @@ class HloInstruction { // Precondition: opcode must be a legitimate variadic operation. 
static std::unique_ptr CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, // at a given index) static std::unique_ptr CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const PrecisionConfig& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers); - - // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 - // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS - // and the RHS must be of rank 2. 
- static std::unique_ptr CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to @@ -446,9 +463,9 @@ class HloInstruction { // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. // - // TODO(b/79737069): Rename this to AllReduce. + // TODO(b/117564385): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); @@ -467,9 +484,18 @@ class HloInstruction { // be concatenated in the order of 1, 2, 3; another Alltoall will be applied // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. static std::unique_ptr CreateAllToAll( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups); + // Creates a communication instruction that permutes data across replicas. + // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // consisting of 0(s) in `shape`. + static std::unique_ptr CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, @@ -526,17 +552,15 @@ class HloInstruction { // start/limit indices. 
static std::unique_ptr CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides); // Creates a slice instruction, where the first operand is sliced by // start indices specified in the second operand, and by size specified in // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + HloInstruction* start_indices, absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. @@ -547,7 +571,7 @@ class HloInstruction { // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. static std::unique_ptr CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) @@ -559,7 +583,7 @@ class HloInstruction { // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // A more general, multiple-argument version of the above. @@ -574,9 +598,9 @@ class HloInstruction { // ... // TODO(b/112040122): Add support to this in HLO passes and in backends. 
static std::unique_ptr CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // Creates a reduce-window instruction, where the computation (given @@ -613,7 +637,7 @@ class HloInstruction { // Creates a broadcast instruction. static std::unique_ptr CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a sequence of instructions that performs an explicit broadcast of // the operand to the target shape. @@ -643,12 +667,12 @@ class HloInstruction { // Creates a transpose instruction which permutes the operand dimensions. static std::unique_ptr CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); - // Creates a sort op, with a keys operand, and an optional values operand. + // Creates a sort op, with a keys operand, and optional values operands. static std::unique_ptr CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. 
For @@ -669,7 +693,7 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); static std::unique_ptr CreateScatter( const Shape& shape, HloInstruction* operand, @@ -693,37 +717,48 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation); // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. static std::unique_ptr CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation); // Creates a custom call instruction that applies the given custom call target - // to the given operands. "shape" is the resultant shape. + // to the given operands. "opaque" can be an arbitrary string with a + // backend-specific interpretation. "shape" is the resultant shape. + static std::unique_ptr CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque = ""); + + // Overload which constrains the layouts of the operand and result. 'shape' + // and 'operand_shapes_with_layout' must have layouts. + // 'operand_shapes_with_layout' must have a compatible element for each + // operand. static std::unique_ptr CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - absl::string_view custom_call_target); + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque = ""); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. 
static std::unique_ptr CreateTuple( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); // Creates a reverse instruction, which reverses the order of the elements // in the specified dimensions. static std::unique_ptr CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Afterall instruction used for joining or creating new values of // token type which thread through side-effecting operations. Operands must // all be tokens, and there must be at least one operand. static std::unique_ptr CreateAfterAll( - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates an AfterAll instruction which creates a token type out of thin air // (no operands). This is a separate method from CreateAfterAll to facility @@ -857,11 +892,6 @@ class HloInstruction { return false; } - if (!ContainersEqual(precision_config_.operand_precision(), - other.precision_config_.operand_precision())) { - return false; - } - return IdenticalSlowPath(other, eq_computations); } @@ -1029,7 +1059,7 @@ class HloInstruction { // Returns true if this instruction can be legally fused into a fusion // instruction. - bool IsFusable() const; + bool IsFusible() const; // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. @@ -1076,15 +1106,6 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } - // Retrieves the operand side metadata of a kDomain instruction. - const DomainMetadata& operand_side_metadata() const { - return *operand_side_metadata_; - } - // Retrieves the user side metadata of a kDomain instruction. 
- const DomainMetadata& user_side_metadata() const { - return *user_side_metadata_; - } - // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1092,18 +1113,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Returns data on the dimension numbers used for a dot operation. - const DotDimensionNumbers& dot_dimension_numbers() const { - CHECK(dot_dimension_numbers_ != nullptr); - return *dot_dimension_numbers_; - } - - // Returns the dump string of the dot dimension numbers. - string DotDimensionNumbersToString() const; - - // Returns the dump string of the precision configuration. - string PrecisionConfigToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1114,8 +1123,7 @@ class HloInstruction { // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). @@ -1254,12 +1262,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. - const PrecisionConfigProto& precision_config() const { - return precision_config_; - } - void set_precision_config(const PrecisionConfigProto& precision_config) { - precision_config_ = precision_config; - } + // Precondition: opcode must be kConvolution or kDot. 
+ const PrecisionConfig& precision_config() const; // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1335,9 +1339,6 @@ class HloInstruction { int64 slice_strides(int64 dimension) const; const std::vector& slice_strides() const; - // Delegates to HloSliceInstruction::IsInPlaceSlice. - bool IsInPlaceSlice() const; - // Returns the literal associated with this instruction. const Literal& literal() const; @@ -1429,9 +1430,12 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllToAllInstruction::replica_groups. + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector& replica_groups() const; + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. + const std::vector>& source_target_pairs() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); @@ -1465,6 +1469,8 @@ class HloInstruction { // dimension and output feature dimension. int64 feature_group_count() const; + void set_feature_group_count(int64 feature_group_count); + // Delegates to HloSelectAndScatterInstruction::select. HloComputation* select() const; @@ -1492,11 +1498,20 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_dimension_numbers. const GatherDimensionNumbers& gather_dimension_numbers() const; // Delegates to HloGatherInstruction::gather_slice_sizes. - tensorflow::gtl::ArraySlice gather_slice_sizes() const; + absl::Span gather_slice_sizes() const; // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Delegates to HloDotInstruction::dot_dimension_numbers(). + const DotDimensionNumbers& dot_dimension_numbers() const; + + // Delegates to HloDomainInstruction::operand_side_metadata(). 
+ const DomainMetadata& operand_side_metadata() const; + + // Delegates to HloDomainInstruction::user_side_metadata(). + const DomainMetadata& user_side_metadata() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1518,7 +1533,7 @@ class HloInstruction { // Removes a list of operands with the given indices in ascending order. void RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices); + absl::Span ascending_indices); void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); @@ -1548,8 +1563,7 @@ class HloInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { // TODO(b/80131774): This should be pure virtual. LOG(FATAL) << "Unimplemented method."; @@ -1595,7 +1609,7 @@ class HloInstruction { // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Adds a user for this instruction. void AddUser(HloInstruction* user); @@ -1619,6 +1633,10 @@ class HloInstruction { InstructionVector operands_; // The set of control predecessors of this instruction. + // Note that the order of the instructions in the vector influences the order + // computed in HloComputation::ComputeInstructionPostOrder, which may + // influence the result of the compilation by changing the scheduling. We are + // not sure if it matters. std::vector control_predecessors_; // The users of this instruction. Users are HLOs where this instruction is an @@ -1626,7 +1644,7 @@ class HloInstruction { // members. The set enables fast membership testing and the vector enables // fast, stable iteration. 
std::vector users_; - std::unordered_set user_set_; + absl::flat_hash_set user_set_; // The set of control successors of this instruction. std::vector control_successors_; @@ -1637,22 +1655,12 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Describes the dimension numbers used for a dot. - std::unique_ptr dot_dimension_numbers_; - - // Used to tag kCopy instructions that are eligible for copy elision. - bool copy_elision_allowed_ = true; - // The sharding, if one exists. // Uses std::shared_ptr to allow reuse of the same sharding object between // HloInstructions and other components as HloSharding can be very large for // many element tuples. std::shared_ptr sharding_; - // Fields used by the kDomain instruction. - std::unique_ptr operand_side_metadata_; - std::unique_ptr user_side_metadata_; - // Computations called by this instruction. std::vector called_computations_; @@ -1666,10 +1674,6 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; - // Information used to communicate to the implementation about the algorithm - // used to produce results. See the documentation on precision_config(). - PrecisionConfigProto precision_config_; - // String identifier for instruction. 
string name_; @@ -1692,12 +1696,12 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); -string PrecisionToString(const PrecisionConfigProto::Precision& precision); +string PrecisionToString(const PrecisionConfig::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr StringToRandomDistribution(const string& name); -StatusOr StringToPrecision(const string& name); +StatusOr StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1706,21 +1710,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. Exception: null pointer values compare less than non-null. -// -// Note that this cannot be used for HLO instructions across multiple modules -// since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, - const HloInstruction* const& rhs) const { - if (rhs == nullptr) { - // Nothing compares less than nullptr. - return false; - } - if (lhs == nullptr) { - return true; - } - return lhs->unique_id() < rhs->unique_id(); - } + const HloInstruction* const& rhs) const; }; template diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 504b13043f86f152cc83b0b961bf2e8fa3ad2afb..d93351fe0435b5f29035dc4ea0621a8c576bfd5a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,10 +39,8 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloTestBase { +class HloInstructionTest : public HloVerifiedTestBase { protected: - HloInstructionTest() {} - Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -53,7 +51,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { public: Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("not implemented %s", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } Status HandleParameter(HloInstruction* parameter) override { @@ -137,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) { auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); - EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); + EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); + EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); EXPECT_EQ(0, parameter->operand_count()); } @@ -1086,16 +1085,14 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { // Fused expression: - // - // x y - // \ / \ - // min broadcast + // y + // / + // x broadcast + // \ / | + // min | // \ / // sub // - // The fusion instruction is elementwise on `x` because the only path from x - // to sub contains only elementwise operations. It is not elementwise on `y` - // because the path y->broadcast->sub is not all elementwise. 
const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); @@ -1104,10 +1101,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y")); - HloInstruction* min = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y)); HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0})); + builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {})); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, broadcast)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); @@ -1118,10 +1115,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { EXPECT_FALSE(fusion->IsElementwise()); for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); ++operand_idx) { - if (fusion->operand(operand_idx) == x) { - EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); - } else { + if (fusion->operand(operand_idx) == y) { EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); + } else { + EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); } } } @@ -1151,8 +1148,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1192,8 +1189,8 @@ TEST_F(HloInstructionTest, 
NoRedundantFusionOperandsAfterReplacingUse) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1243,12 +1240,12 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2))); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kAdd, dot, add_operand)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -1324,8 +1321,8 @@ TEST_F(HloInstructionTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().set_print_metadata(false); @@ -1489,8 +1486,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { DotDimensionNumbers dot_dnums; 
dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().Canonical(); @@ -1531,8 +1528,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1587,8 +1584,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1743,5 +1740,23 @@ TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { << clone->convolution_dimension_numbers().DebugString(); } +TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { + constexpr char kHloString[] = R"( + HloModule test_module + ENTRY test { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, + dim_labels=b0f_0io->b0f, operand_precision={high,default} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, 
ParseHloString(kHloString)); + auto* conv = module->entry_computation()->root_instruction(); + + auto clone = conv->Clone(); + EXPECT_THAT( + clone->precision_config().operand_precision(), + ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index a0de253eda729c4e8c3bf3bef3142e60c7a59c34..179ace2cdb76051fecdeb7e0cbdcd808bf9fee25 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" @@ -27,8 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/window_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -47,6 +48,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast(precision))); + }), + "}"); +} } // 
namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -91,8 +113,7 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( std::unique_ptr HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); return absl::make_unique( @@ -113,8 +134,7 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( std::unique_ptr HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); return absl::make_unique( @@ -135,8 +155,7 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction( std::unique_ptr HloBatchNormGradInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); return absl::make_unique( @@ -144,9 +163,9 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl( new_operands[4], epsilon(), feature_index()); } -HloFftInstruction::HloFftInstruction( - const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) +HloFftInstruction::HloFftInstruction(const Shape& shape, + HloInstruction* operand, FftType fft_type, + absl::Span fft_length) : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { fft_length_.assign(fft_length.begin(), fft_length.end()); AppendOperand(operand); @@ -177,8 +196,7 @@ bool HloFftInstruction::IdenticalSlowPath( } std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 
1); return absl::make_unique(shape, new_operands[0], fft_type_, @@ -196,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_channel_id(channel_id_); + proto.set_is_host_transfer(is_host_transfer_); return proto; } @@ -232,8 +251,7 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand, } std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -250,8 +268,7 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, std::unique_ptr HloSendDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -271,8 +288,7 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape, } std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -293,8 +309,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, std::unique_ptr HloRecvDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -303,7 +318,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCollectiveInstruction::HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, - tensorflow::gtl::ArraySlice 
operands, + absl::Span operands, const std::vector& replica_groups) : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { @@ -337,15 +352,14 @@ bool HloCollectiveInstruction::IdenticalSlowPath( /*eq_computations*/) const { const auto& casted_other = static_cast(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }); + return absl::c_equal(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return absl::c_equal(a.replica_ids(), b.replica_ids()); + }); } HloAllReduceInstruction::HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) @@ -393,8 +407,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( std::unique_ptr HloAllReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( shape, new_operands, to_apply(), replica_groups(), @@ -402,23 +415,72 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( } HloAllToAllInstruction::HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups) : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, replica_groups) {} std::unique_ptr HloAllToAllInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique(shape, new_operands, replica_groups()); } 
-HloReverseInstruction::HloReverseInstruction( +HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) + const std::vector>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} + +HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + return proto; +} + +std::vector +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + std::vector strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); + return result; +} + +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return absl::c_equal(source_target_pairs(), + casted_other.source_target_pairs(), + [](const std::pair& a, + const std::pair& b) { return a == b; }); +} + +std::unique_ptr +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], source_target_pairs()); +} + +HloReverseInstruction::HloReverseInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span dimensions) : HloInstruction(HloOpcode::kReverse, shape), dimensions_(dimensions.begin(), dimensions.end()) { AppendOperand(operand); @@ -446,8 +508,7 @@ bool 
HloReverseInstruction::IdenticalSlowPath( } std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], @@ -455,7 +516,7 @@ std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( } HloConcatenateInstruction::HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { for (auto operand : operands) { @@ -487,16 +548,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath( std::unique_ptr HloConcatenateInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, new_operands, dimensions(0)); } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span args, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { @@ -531,21 +591,20 @@ bool HloReduceInstruction::IdenticalSlowPath( } std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); + CHECK_EQ(new_operands.size() % 2, 0); return absl::make_unique(shape, new_operands, dimensions(), to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) + 
absl::Span values) : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { AppendOperand(keys); - if (values) { - AppendOperand(values); + for (auto* value : values) { + AppendOperand(value); } } @@ -571,28 +630,18 @@ bool HloSortInstruction::IdenticalSlowPath( } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; - HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; return absl::make_unique(shape, dimensions(0), keys, - values); + new_operands.subspan(1)); } HloTransposeInstruction::HloTransposeInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -626,8 +675,7 @@ bool HloTransposeInstruction::IdenticalSlowPath( std::unique_ptr HloTransposeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], @@ -636,7 +684,7 @@ HloTransposeInstruction::CloneWithNewOperandsImpl( HloBroadcastInstruction::HloBroadcastInstruction( const Shape& shape, HloInstruction* operand, - 
tensorflow::gtl::ArraySlice broadcast_dimension) + absl::Span broadcast_dimension) : HloInstruction(HloOpcode::kBroadcast, shape), dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { AppendOperand(operand); @@ -665,17 +713,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath( std::unique_ptr HloBroadcastInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], dimensions()); } -HloMapInstruction::HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) +HloMapInstruction::HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { for (auto operand : operands) { AppendOperand(operand); @@ -724,17 +771,16 @@ bool HloMapInstruction::IdenticalSlowPath( } std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, new_operands, to_apply()); } -HloSliceInstruction::HloSliceInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) +HloSliceInstruction::HloSliceInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) : HloInstruction(HloOpcode::kSlice, shape), slice_starts_(start_indices.begin(), start_indices.end()), slice_limits_(limit_indices.begin(), limit_indices.end()), @@ -785,16 +831,15 @@ bool HloSliceInstruction::IdenticalSlowPath( } std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - 
tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } -HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) - : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), +HloConstantInstruction::HloConstantInstruction(Literal literal) + : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} HloConstantInstruction::HloConstantInstruction(const Shape& shape) @@ -802,7 +847,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - if (literal_ != nullptr) { + if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } return proto; @@ -824,7 +869,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, if (!mutable_array_subshape->has_layout() || !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); + *literal_ = literal_->Relayout(new_layout, shape_index); *mutable_array_subshape->mutable_layout() = new_layout; } } @@ -839,10 +884,10 @@ bool HloConstantInstruction::IdenticalSlowPath( std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(literal_->CloneToUnique()); + CHECK(literal_.has_value()); + return absl::make_unique(literal_->Clone()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -850,7 +895,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show 
the actual value in place of an empty operand list. - if (literal_ != nullptr && + if (literal_.has_value() && ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple @@ -885,7 +930,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstructionProto HloTraceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + *proto.mutable_literal() = literal_.ToProto(); return proto; } @@ -897,8 +942,7 @@ bool HloTraceInstruction::IdenticalSlowPath( } std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); } @@ -916,7 +960,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape, HloFusionInstruction::HloFusionInstruction( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { for (auto operand : operands) { @@ -991,7 +1035,8 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( const int64 param_no = operand_count(); // Name the parameter after the instruction it represents in the outer // (non-fusion) computation. 
- string param_name = StrCat(new_operand->name(), ".param_", param_no); + // string param_name = StrCat(new_operand->name(), ".param_", param_no); + string param_name = StrCat("param_", param_no); HloInstruction* fused_parameter = fused_instructions_computation()->AddParameter( HloInstruction::CreateParameter(param_no, new_operand->shape(), @@ -1047,7 +1092,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( // Note that we add the unfused instructions to this->parent_ computation. // This is necessary because the unique_id needs for an instruction and // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap old_to_new; + absl::flat_hash_map old_to_new; std::vector unfused_instructions; auto computation_to_merge = instruction_to_merge->fused_instructions_computation(); @@ -1152,7 +1197,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal( HloInstruction* HloFusionInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations().empty()) { @@ -1323,8 +1368,7 @@ bool HloFusionInstruction::IdenticalSlowPath( } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloModule* module = context != nullptr ? 
context->module() : GetModule(); HloComputation* new_fused_computation = nullptr; @@ -1341,7 +1385,7 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( } Status HloFusionInstruction::DeduplicateFusionOperands() { - tensorflow::gtl::FlatMap operand_indices; + absl::flat_hash_map operand_indices; std::vector operands_to_remove; for (int i = 0; i < operand_count(); ++i) { auto emplace_result = operand_indices.emplace(operand(i), i); @@ -1362,7 +1406,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() { HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) + absl::Span parameters) : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { for (HloInstruction* param : parameters) { AppendOperand(param); @@ -1393,8 +1437,7 @@ bool HloRngInstruction::IdenticalSlowPath( } std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, distribution_, new_operands); @@ -1430,8 +1473,7 @@ bool HloParameterInstruction::IdenticalSlowPath( std::unique_ptr HloParameterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(parameter_number_, shape, name()); @@ -1440,7 +1482,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl( HloGetTupleElementInstruction::HloGetTupleElementInstruction( const Shape& shape, HloInstruction* operand, int64 index) : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); AppendOperand(operand); } @@ -1466,8 +1507,7 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath( std::unique_ptr HloGetTupleElementInstruction::CloneWithNewOperandsImpl( 
- const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -1509,8 +1549,7 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath( std::unique_ptr HloReducePrecisionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -1550,8 +1589,7 @@ bool HloInfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -1565,9 +1603,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), outfeed_config_(outfeed_config) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); AppendOperand(operand); AppendOperand(token_operand); } @@ -1596,8 +1631,7 @@ bool HloOutfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1606,12 +1640,14 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 
feature_group_count) + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), + feature_group_count_(feature_group_count), window_(window), convolution_dimension_numbers_(dimension_numbers), - feature_group_count_(feature_group_count) { + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1638,6 +1674,8 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_window() = window_; *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; + proto.set_feature_group_count(feature_group_count_); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1649,7 +1687,15 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); - extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1659,21 +1705,25 @@ bool HloConvolutionInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast(other); + if (feature_group_count_ != other.feature_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr 
HloConvolutionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - shape, new_operands[0], new_operands[1], window(), - convolution_dimension_numbers_, feature_group_count_); + shape, new_operands[0], new_operands[1], feature_group_count_, window(), + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1712,8 +1762,7 @@ bool HloReduceWindowInstruction::IdenticalSlowPath( std::unique_ptr HloReduceWindowInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1761,8 +1810,7 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath( std::unique_ptr HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); return absl::make_unique( @@ -1771,11 +1819,29 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( } HloCustomCallInstruction::HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - absl::string_view custom_call_target) + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque) : HloInstruction(HloOpcode::kCustomCall, shape), - custom_call_target_(custom_call_target.begin(), - custom_call_target.end()) { + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(false) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + 
+HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque, + absl::Span operand_shapes_with_layout) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(true), + operand_shapes_with_layout_(operand_shapes_with_layout.begin(), + operand_shapes_with_layout.end()) { for (auto operand : operands) { AppendOperand(operand); } @@ -1791,6 +1857,14 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_custom_call_opaque(opaque_); + proto.set_feature_group_count(feature_group_count_); + if (layout_constrained()) { + proto.set_constrain_layout(true); + for (const Shape& shape : operand_shapes_with_layout_) { + *proto.add_operand_shapes_with_layout() = shape; + } + } return proto; } @@ -1805,11 +1879,27 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( "dim_labels=", ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + // If the opaque string becomes enormous we may want to reconsider printing + // this inline and consider other options. 
+ if (!opaque_.empty()) { + extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); + } + if (layout_constrained()) { + std::vector shape_strings; + for (const Shape& shape : operand_shapes_with_layout_) { + shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape)); + } + extra.push_back(StrCat("operand_layout_constraints={", + StrJoin(shape_strings, ", "), "}")); + } return extra; } @@ -1832,22 +1922,26 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.convolution_dimension_numbers()))) { return false; } - return custom_call_target_ == casted_other.custom_call_target_; + if (feature_group_count_ != casted_other.feature_group_count_) { + return false; + } + return custom_call_target_ == casted_other.custom_call_target_ && + opaque_ == casted_other.opaque_; } std::unique_ptr HloCustomCallInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { auto cloned = absl::make_unique( - shape, new_operands, custom_call_target()); + shape, new_operands, custom_call_target(), opaque()); if (window_ != nullptr) { cloned->set_window(*window_); } if (convolution_dimension_numbers_ != nullptr) { cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } + cloned->set_feature_group_count(feature_group_count_); return std::move(cloned); } @@ -1881,8 +1975,7 @@ bool HloPadInstruction::IdenticalSlowPath( } std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique(shape, new_operands[0], @@ -1891,7 +1984,7 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - 
tensorflow::gtl::ArraySlice slice_sizes) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); @@ -1921,8 +2014,7 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath( std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1932,7 +2024,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( HloGatherInstruction::HloGatherInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); AppendOperand(start_indices); @@ -1961,10 +2053,9 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice offset_dims, - tensorflow::gtl::ArraySlice collapsed_slice_dims, - tensorflow::gtl::ArraySlice start_index_map, - int64 index_vector_dim) { + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; for (int64 output_window_dim : offset_dims) { gather_dim_numbers.add_offset_dims(output_window_dim); @@ -2007,8 +2098,7 @@ bool HloGatherInstruction::IdenticalSlowPath( } std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -2052,9 +2142,9 @@ string 
HloScatterInstruction::ScatterDimensionNumbersToString() const { /* static */ ScatterDimensionNumbers HloScatterInstruction::MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim) { ScatterDimensionNumbers scatter_dim_numbers; for (int64 update_window_dim : update_window_dims) { @@ -2094,8 +2184,7 @@ bool HloScatterInstruction::IdenticalSlowPath( } std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); return absl::make_unique( @@ -2103,4 +2192,161 @@ std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + return absl::make_unique(shape, iota_dimension()); +} + 
+HloDotInstruction::HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) + : HloInstruction(HloOpcode::kDot, shape), + dot_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloDotInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; + *proto.mutable_precision_config() = precision_config_; + return proto; +} + +std::vector HloDotInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; +} + +bool HloDotInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); +} + +std::unique_ptr HloDotInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + shape, new_operands[0], new_operands[1], dot_dimension_numbers_, + precision_config_); +} + +string HloDotInstruction::DotDimensionNumbersToString() const { + std::vector result; + const DotDimensionNumbers& dnums = dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); + } + 
result.push_back(StrCat("lhs_contracting_dims={", + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); + + return StrJoin(result, ", "); +} + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && + user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} + +HloInstructionProto HloDomainInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + auto operand_side_sharding = + dynamic_cast(operand_side_metadata_.get()); + if 
(operand_side_sharding) { + *proto.mutable_domain_entry_sharding() = + operand_side_sharding->sharding()->ToProto(); + } + + auto user_side_sharding = + dynamic_cast(user_side_metadata_.get()); + if (user_side_sharding) { + *proto.mutable_domain_exit_sharding() = + user_side_sharding->sharding()->ToProto(); + } + + return proto; +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index efdb9e97819b07ba67075586df227273b8b36f24..5f06dc093248e1d4d36ec845ced1e68c2b9d0752 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -67,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -82,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -97,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -106,7 +103,7 @@ class HloFftInstruction : public HloInstruction { public: explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); FftType fft_type() const { return fft_type_; } const std::vector& fft_length() const { return fft_length_; } @@ -124,8 +121,7 @@ class HloFftInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes FFT type for an FFT instruction. @@ -174,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -187,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -200,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -213,8 +206,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -227,7 +219,7 @@ class HloCollectiveInstruction : public HloInstruction { protected: explicit HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, const std::vector& replica_groups); HloInstructionProto ToProto() const override; @@ -245,7 +237,7 @@ class HloCollectiveInstruction : public HloInstruction { class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); @@ -274,8 +266,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The string representation of the barrier config used for CrossReplicaSum. 
@@ -290,21 +281,49 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operand, + const Shape& shape, absl::Span operands, const std::vector& replica_groups); private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; +class HloCollectivePermuteInstruction : public HloInstruction { + public: + explicit HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); + + const std::vector>& source_target_pairs() const { + return source_target_pairs_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + const std::vector> source_target_pairs_; +}; + class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. 
const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -320,8 +339,7 @@ class HloReverseInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -329,9 +347,9 @@ class HloReverseInstruction : public HloInstruction { class HloConcatenateInstruction : public HloInstruction { public: - explicit HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - int64 dimension); + explicit HloConcatenateInstruction(const Shape& shape, + absl::Span operands, + int64 dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -349,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -358,26 +375,28 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: - explicit HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reduce_computation); + explicit HloReduceInstruction(const Shape& shape, + absl::Span args, + absl::Span dimensions_to_reduce, + HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the number of input arrays (and, consequentially, the number of + // init values) this reduce has. + int64 input_count() const { return operand_count() / 2; } + // Returns the input tensors to be reduced. - tensorflow::gtl::ArraySlice inputs() const { - return tensorflow::gtl::ArraySlice(operands(), 0, - operand_count() / 2); + absl::Span inputs() const { + return absl::MakeSpan(operands()).subspan(0, input_count()); } // Returns the init values of the reduction. - tensorflow::gtl::ArraySlice init_values() const { - return tensorflow::gtl::ArraySlice( - operands(), operand_count() / 2, operand_count()); + absl::Span init_values() const { + return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); } private: @@ -389,8 +408,7 @@ class HloReduceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -400,14 +418,19 @@ class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } // Returns the sort dimension for this instruction - int64 sort_dimension() { return dimensions(0); } + int64 sort_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the key operand to this instruction. + const HloInstruction* keys() const { return operand(0); } + HloInstruction* mutable_keys() { return mutable_operand(0); } + // Returns the number of value operands. + int64 values_count() const { return operand_count() - 1; } private: std::vector ExtraAttributesToStringImpl( @@ -418,8 +441,7 @@ class HloSortInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -427,9 +449,8 @@ class HloSortInstruction : public HloInstruction { class HloTransposeInstruction : public HloInstruction { public: - explicit HloTransposeInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -447,8 +468,7 @@ class HloTransposeInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -456,9 +476,8 @@ class HloTransposeInstruction : public HloInstruction { class HloBroadcastInstruction : public HloInstruction { public: - explicit HloBroadcastInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension); + explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, + absl::Span broadcast_dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -474,8 +493,7 @@ class HloBroadcastInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -483,9 +501,9 @@ class HloBroadcastInstruction : public HloInstruction { class HloMapInstruction : public HloInstruction { public: - explicit HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); + explicit HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -503,8 +521,7 @@ class HloMapInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -513,9 +530,9 @@ class HloMapInstruction : public HloInstruction { class HloSliceInstruction : public HloInstruction { public: explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); HloInstructionProto ToProto() const override; @@ -534,17 +551,6 @@ class HloSliceInstruction : public HloInstruction { } const std::vector& slice_strides() const { return slice_strides_; } - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. 
- bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -554,28 +560,24 @@ class HloSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; }; class HloConstantInstruction : public HloInstruction { public: - explicit HloConstantInstruction(std::unique_ptr literal); + explicit HloConstantInstruction(Literal literal); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } // Returns whether there is literal associated with this instruction. - bool HasLiteral() const { return literal_ != nullptr; } + bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -597,18 +599,16 @@ class HloConstantInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + absl::optional literal_; }; class HloTraceInstruction : public HloInstruction { public: explicit HloTraceInstruction(const string& tag, HloInstruction* operand); // Returns a tag to be used in tracing. - string TracingTag() const { return literal_->GetR1U8AsString(); } + string TracingTag() const { return literal_.GetR1U8AsString(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -619,11 +619,9 @@ class HloTraceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + Literal literal_; }; class HloFusionInstruction : public HloInstruction { @@ -631,10 +629,9 @@ class HloFusionInstruction : public HloInstruction { explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); - explicit HloFusionInstruction( - const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, - HloComputation* fusion_computation); + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + absl::Span operands, + HloComputation* fusion_computation); string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -747,8 +744,7 @@ class HloFusionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The type of the fusion. Used by kFusion only. @@ -757,9 +753,9 @@ class HloFusionInstruction : public HloInstruction { class HloRngInstruction : public HloInstruction { public: - explicit HloRngInstruction( - const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + explicit HloRngInstruction(const Shape& shape, + RandomDistribution distribution, + absl::Span parameters); // Returns the random distribution for this rng node. RandomDistribution random_distribution() const { return distribution_; } // Returns a serialized representation of this instruction. @@ -776,8 +772,7 @@ class HloRngInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The distribution requested for random number generation. @@ -802,8 +797,7 @@ class HloParameterInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 parameter_number_ = 0; @@ -827,8 +821,7 @@ class HloGetTupleElementInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 tuple_index_ = -1; @@ -856,8 +849,7 @@ class HloReducePrecisionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The bit sizes for a reduce-precision operation. @@ -894,8 +886,7 @@ class HloInfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The string representation of the infeed configuration. @@ -910,7 +901,6 @@ class HloOutfeedInstruction : public HloInstruction { absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); return outfeed_shape_; } // Returns the config for the Outfeed instruction. @@ -927,8 +917,7 @@ class HloOutfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Shape of outfeed request. 
@@ -941,9 +930,9 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + const PrecisionConfig& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -956,6 +945,16 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -969,15 +968,18 @@ class HloConvolutionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - Window window_; - // Describes the dimension numbers used for a convolution. - ConvolutionDimensionNumbers convolution_dimension_numbers_; // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // Describes the window used for a convolution. + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1001,8 +1003,7 @@ class HloReduceWindowInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; @@ -1050,17 +1051,26 @@ class HloSelectAndScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - absl::string_view custom_call_target); + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque); + + // Constructor for a custom call with constrained layout. 'shape' and + // 'operands_with_layout' must all have layouts. + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque, + absl::Span operand_shapes_with_layout); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1080,10 +1090,25 @@ class HloCustomCallInstruction : public HloInstruction { convolution_dimension_numbers_ = absl::make_unique(dnums); } + const string& opaque() const { return opaque_; } const string& custom_call_target() const { return custom_call_target_; } + void set_feature_group_count(int64 feature_group_count) { + feature_group_count_ = feature_group_count; + } + int64 feature_group_count() const { return feature_group_count_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns whether the result and operand layouts are constrained. + bool layout_constrained() const { return layout_constrained_; } + + // Returns the shapes (with layout) of the operands. CHECKs if this custom + // call does not have constrained layouts. 
+ const std::vector& operand_shapes_with_layout() const { + CHECK(layout_constrained()); + return operand_shapes_with_layout_; + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1093,15 +1118,23 @@ class HloCustomCallInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a global symbol to call. string custom_call_target_; + // Opaque string interpreted by the backend. + string opaque_; // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; + // The number of feature groups. This is used for grouped convolutions. + int64 feature_group_count_; + // Whether the result and operand layouts are constrained. + bool layout_constrained_; + // For layout-constrained custom calls, this vector holds the shape with + // layout for each operand. + std::vector operand_shapes_with_layout_; }; class HloPadInstruction : public HloInstruction { @@ -1123,8 +1156,7 @@ class HloPadInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The padding configuration that describes the edge padding and interior @@ -1134,10 +1166,10 @@ class HloPadInstruction : public HloInstruction { class HloDynamicSliceInstruction : public HloInstruction { public: - explicit HloDynamicSliceInstruction( - const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + explicit HloDynamicSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1159,8 +1191,7 @@ class HloDynamicSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [start, start + size) range size for a dynamic slice @@ -1174,12 +1205,12 @@ class HloGatherInstruction : public HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; } - tensorflow::gtl::ArraySlice gather_slice_sizes() const { + absl::Span gather_slice_sizes() const { return gather_slice_sizes_; } // Returns the dump string of the gather dimension numbers. 
@@ -1189,10 +1220,9 @@ class HloGatherInstruction : public HloInstruction { // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice offset_dims, - tensorflow::gtl::ArraySlice collapsed_slice_dims, - tensorflow::gtl::ArraySlice start_index_map, - int64 index_vector_dim); + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim); private: std::vector ExtraAttributesToStringImpl( @@ -1202,8 +1232,7 @@ class HloGatherInstruction : public HloInstruction { const std::function& eq_computations) const override; std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr gather_dimension_numbers_; @@ -1228,9 +1257,9 @@ class HloScatterInstruction : public HloInstruction { // Creates an instance of ScatterDimensionNumbers. static ScatterDimensionNumbers MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim); private: @@ -1242,13 +1271,117 @@ class HloScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr scatter_dimension_numbers_; }; +class HloIotaInstruction : public HloInstruction { + public: + explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + int64 iota_dimension() const { return iota_dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + const int64 iota_dimension_; +}; + +class HloDotInstruction : public HloInstruction { + public: + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); + + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + return dot_dimension_numbers_; + } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. 
Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + + // Describes the dimension numbers used for a dot. + DotDimensionNumbers dot_dimension_numbers_; + + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; +}; + +class HloDomainInstruction : public HloInstruction { + public: + explicit HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. 
+ const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 0e49d343d6a9cd09e6575dca6055e982c0bfdc07..971a9a20636c80820306d512af9e7ff4a14b79b5 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -204,7 +204,7 @@ TokKind HloLexer::LexIdentifier() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. 
static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"}; + R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); @@ -269,7 +269,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = std::string(identifier); + str_val_ = string(identifier); return TokKind::kIdent; } @@ -306,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - CHECK(absl::SimpleAtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_)); + CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); return TokKind::kDecimal; } @@ -407,11 +406,7 @@ TokKind HloLexer::LexString() { absl::string_view raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; - // TODO(b/113077997): Change to absl::CUnescape once it works properly with - // copy-on-write std::string implementations. - if (!tensorflow::str_util::CUnescape( // non-absl ok - tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok - &str_val_, &error)) { + if (!absl::CUnescape(raw, &str_val_, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". 
error: " << error; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 3a1dd471c626ae9497cfcca62c30736bcdbb2b38..5bf055f3c012fef687cdc275d62efdf2d4cd5e5c 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers( } } +// Makes sure that if a live instruction is within a computation used in control +// flow operations, we mark live even other related instructions. +void PropagateLivenessThroughControlFlow( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + HloInstruction* caller = callsite.instruction(); + if (caller->opcode() == HloOpcode::kWhile) { + // If a live instruction is within the %while body or condition + // computation, mark the predicate value returned by the condition + // computation live as well. + MarkLiveAtIndex(caller->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); + } else if (caller->opcode() == HloOpcode::kConditional) { + // If a live instruction is within the true or false branches of a + // conditional, we mark the predicate operand live as well. 
+ MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist, + workset); + } + } + } +} + } // namespace HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) @@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() { } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kWhile && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kWhile) { PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kParameter) { PropagateLivenessToParameterCallers(instruction, &live_index_map_, &worklist, &workset, call_graph_.get()); @@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() { MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); } } + PropagateLivenessThroughControlFlow(instruction, &live_index_map_, + &worklist, &workset, call_graph_.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 01b625c29ca2823b2a2490b30a9d4d5128b4c22e..e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); } +TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + 
get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + InnerWhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + InnerWhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + OuterWhileCondition { + cond_param.2 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 + constant.5 = s32[] constant(5) + ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + } + OuterWhileBody { + body_param.2 = (s32[]) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0 + constant.6 = s32[] constant(0) + tuple.2 = 
(s32[]) tuple(constant.6) + inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition, + body=InnerWhileBody + constant.7 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.8, constant.7) + ROOT rtuple = (s32[]) tuple(add.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=OuterWhileCondition, + body=OuterWhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 9ace0d76e0c98420b085f30c0f0042a33b6e7583..1717770301e3666b0a1c23d20b7f2e3bac5f62e4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -179,6 +179,7 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); +HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -188,6 +189,7 @@ HLO_MATCHER(Fusion); HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); +HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); @@ -215,6 +217,7 @@ HLO_MATCHER(Remainder); HLO_MATCHER(Reshape); HLO_MATCHER(Reverse); HLO_MATCHER(Rng); +HLO_MATCHER(Scatter); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc similarity index 64% rename from tensorflow/compiler/xla/service/hlo_scheduling.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 
56b14f9fef930af3b8255954700b30fabb1a11de..5cee865b7ad34eded1743d9d5455bb40febf6182 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -30,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -71,11 +72,11 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. 
- static StatusOr> Run( + static StatusOr Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { ListScheduler scheduler(computation, points_to_analysis, size_function, memory_by_computation); @@ -100,7 +101,7 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), @@ -111,7 +112,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - tensorflow::gtl::FlatSet instr_uses; + absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( [&](const ShapeIndex& /*index*/, @@ -194,13 +195,15 @@ class ListScheduler { return entry; } - // Returns the number of bytes freed if the HLO instruction is scheduled. - // If the instruction calls subcomputations, we count the memory used by the - // subcomputations as memory "defined" by the instruction. This is not - // entirely accurate, because subcomputation memory will be freed after the - // instruction finishes. But it is more accurate than not taking - // subcomputations into account at all. In the future, we may improve - // accounting for subcomputation memory (b/65409243). + // Returns the number of bytes freed *after* the HLO instruction finishes. + // The current List algorithm only considers two states for an instruction: + // right before it runs, and after it finishes. We don't represent memory + // usage during the execution of an instruction. 
But if the instruction calls + // subcomputations, they are only live during the instruction's execution. + // We end up counting the memory used by subcomputations as memory "defined" + // by the instruction. This is not entirely accurate, but it is more accurate + // than not taking subcomputations into account at all. In the future, we may + // improve accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -222,7 +225,18 @@ class ListScheduler { } } } - return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; + int64 bytes_defined; + if (max_subcomputation_bytes > 0 && + (entry.instruction->opcode() == HloOpcode::kWhile || + entry.instruction->opcode() == HloOpcode::kCall || + entry.instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + bytes_defined = max_subcomputation_bytes; + } else { + bytes_defined = entry.bytes_defined + max_subcomputation_bytes; + } + return freed_bytes - bytes_defined; } // Constructs the scheduling priority of the given instruction. @@ -230,13 +244,12 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - std::vector CreateSchedule() { - std::vector schedule; + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; // Populate the ready list with instructions which have no operands or // control predecessors. - tensorflow::gtl::FlatMap - unscheduled_pred_count; + absl::flat_hash_map unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. 
@@ -252,8 +265,8 @@ class ListScheduler { std::multimap ready_queue; // Map of ready instructions to their iterators in ready_queue. - tensorflow::gtl::FlatMap::iterator> + absl::flat_hash_map::iterator> ready_instructions; auto add_to_ready_queue = [&](HloInstruction* inst) { @@ -263,9 +276,8 @@ class ListScheduler { }; for (auto* instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction) == 0) { + if (instruction->operands().empty() && + instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); } } @@ -348,21 +360,19 @@ class ListScheduler { // Computations are analyzed in post-order. When scheduling an instruction // that includes subcomputations, such as a while loop, we use this map to // look up the memory needed by subcomputations. - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map, and that the map - // entries are std::pair's. - std::unordered_map unscheduled_use_count_; + // LogicalBuffer. + absl::flat_hash_map unscheduled_use_count_; // Set of instructions which have been scheduled. 
- tensorflow::gtl::FlatSet scheduled_instructions_; + absl::flat_hash_set scheduled_instructions_; }; int64 SumLogicalBufferSizes( @@ -375,12 +385,12 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> ScheduleComputationHelper( +StatusOr ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { VLOG(2) << "Computation: " << computation.name(); if (algorithm) { @@ -393,17 +403,17 @@ StatusOr> ScheduleComputationHelper( } // namespace -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. 
int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); - tensorflow::gtl::FlatMap extra_users; - tensorflow::gtl::FlatMap total_sizes; + int64 total_hlos = computation.parent()->instruction_count(); + absl::flat_hash_map extra_users; + absl::flat_hash_map total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; @@ -420,7 +430,7 @@ StatusOr> DFSMemoryScheduler( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; cumulative_total_size += logical_buffer_size; - tensorflow::gtl::FlatSet unique_operands( + absl::flat_hash_set unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; @@ -444,7 +454,7 @@ StatusOr> DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. 
- std::vector sequence; + HloInstructionSequence sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -464,32 +474,30 @@ StatusOr> DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { return ListScheduler::Run(computation, points_to_analysis, size_function, memory_by_computation); } -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { - const auto& post_order = computation.MakeInstructionPostOrder(); - return std::vector{post_order.begin(), - post_order.end()}; + return HloInstructionSequence(computation.MakeInstructionPostOrder()); } -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { // We try a few schedulers and choose whichever returns a lower min-memory, // not accounting for fragmentation. @@ -500,7 +508,7 @@ StatusOr> DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. 
TF_ASSIGN_OR_RETURN( - std::vector list_sequence, + HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -509,7 +517,7 @@ StatusOr> DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -519,7 +527,7 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - std::vector post_order_sequence, + HloInstructionSequence post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -546,223 +554,61 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - tensorflow::gtl::FlatMap memory_by_computation; + absl::flat_hash_map memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, 
one_computation_sequence, *points_to_analysis, + *computation, computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - sequence[computation] = std::move(one_computation_sequence); + schedule.set_sequence(computation, std::move(computation_sequence)); } } - VLOG(1) << "Module schedule:\n" << sequence; - return sequence; + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); } -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); - tensorflow::gtl::FlatMap empty_map; + absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); } -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { - tensorflow::gtl::FlatMap> id_sequence; - for (const auto& computation_sequence : sequence) { - for (const HloInstruction* instruction : computation_sequence.second) { - id_sequence[computation_sequence.first].push_back( - instruction->unique_id()); - } - } - return id_sequence; -} - -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence) { - // Map from unique ID to HloInstruction pointer for instructions in the - // module. - tensorflow::gtl::FlatMap id_to_instruction; - // Set of all HloInstructions in the schedule. 
- tensorflow::gtl::FlatSet ids_in_schedule; - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - for (const HloComputation* computation : nonfusion_computations) { - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK( - id_to_instruction.insert({instruction->unique_id(), instruction}) - .second); - } - for (int id : id_sequence.at(computation)) { - ids_in_schedule.insert(id); - } - } - - // Map from HloInstruction X to newly added instructions (instruction is in - // module, but not in schedule) which use X. If an instruction is not in the - // map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap> - new_instruction_uses; - - // For each newly added instruction, this is the count of the instruction's - // operands that have not yet been scheduled. When this value reaches zero, - // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; - // For each computation, this is the set of newly added instructions which - // have no operands. These must be handled specially and are added to the - // beginning of the schedule. - tensorflow::gtl::FlatMap> - new_zero_operand_instructions; - for (const HloComputation* computation : nonfusion_computations) { - new_zero_operand_instructions[computation] = {}; - for (const HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { - // This is a newly added instruction which is not in the schedule. 
- for (const HloInstruction* operand : instruction->operands()) { - new_instruction_uses[operand].push_back(instruction); - } - if (instruction->operands().empty()) { - new_zero_operand_instructions[computation].push_back(instruction); - } - unscheduled_operand_count[instruction] = instruction->operand_count(); - } - } - } - - // Update the schedule with the newly added instructions, and remove any - // instructions no longer in the graph. - for (const HloComputation* computation : nonfusion_computations) { - std::vector old_computation_sequence = - std::move(sequence->at(computation)); - sequence->at(computation).clear(); - - // Create a worklist of newly added instructions which are ready to be added - // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; - for (const HloInstruction* instruction : - new_zero_operand_instructions.at(computation)) { - worklist.push(instruction); - } - - // Lambda which schedules all instructions on the worklist. - auto schedule_worklist = [&]() { - while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); - worklist.pop(); - sequence->at(computation).push_back(instruction); - std::vector* new_users = - tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); - if (new_users != nullptr) { - // This just-scheduled instruction has users which are newly added to - // the module. Update the number of unscheduled operands and push the - // newly added instruction to the worklist if it is ready to - // schedule. 
- for (const HloInstruction* new_user : *new_users) { - unscheduled_operand_count.at(new_user)--; - CHECK_GE(unscheduled_operand_count.at(new_user), 0); - if (unscheduled_operand_count.at(new_user) == 0) { - worklist.push(new_user); - } - } - } - } - }; - - schedule_worklist(); - for (int id : id_sequence.at(computation)) { - auto it = id_to_instruction.find(id); - if (it == id_to_instruction.end()) { - // This instruction in the schedule is no longer in the module. - continue; - } - const HloInstruction* instruction = it->second; - worklist.push(instruction); - schedule_worklist(); - } - } - - TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); - return Status::OK(); +HloMemoryScheduler::HloMemoryScheduler( + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) + : size_function_(size_function), algorithm_(algorithm) {} + +StatusOr HloMemoryScheduler::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; } -Status VerifySchedule( - const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence) { - VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(2, module.ToString()); - VLOG(2) << sequence; - - // Verify the set of computations in the sequence is exactly the set of - // computations in the module. - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); - tensorflow::gtl::FlatSet computations_in_module( - module.computations().begin(), module.computations().end()); - for (const auto& computation_sequence : sequence) { - TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); - } - - // For each computation verify the set of instructions is the same and that - // each dependency and control edge is honored. 
- for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; - int pos = 0; - for (const HloInstruction* instruction : sequence.at(computation)) { - TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) - << "Instruction " << instruction->name() - << " appears more than once in the schedule"; - pos++; - } - - TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) - << "Instruction " << instruction->name() << " is not in schedule"; - } - - for (const HloInstruction* instruction : computation->instructions()) { - for (const HloInstruction* operand : instruction->operands()) { - TF_RET_CHECK(instruction_position.at(operand) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its operand " << operand->name(); - } - - for (const HloInstruction* pred : instruction->control_predecessors()) { - TF_RET_CHECK(instruction_position.at(pred) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its control predecessor " - << pred->name(); - } - } - } - - return Status::OK(); +StatusOr HloDescheduler::Run(HloModule* module) { + bool changed = module->has_schedule(); + module->clear_schedule(); + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..a4c1d3db8170a1725043def576f913e09b352e5d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -0,0 +1,124 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. 
+typedef std::function( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&, + const absl::flat_hash_map&)> + MemorySchedulerAlgorithm; + +// List scheduler +StatusOr ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const absl::flat_hash_map& + memory_by_computation); + +// DFS-order scheduler +StatusOr DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const absl::flat_hash_map& + memory_by_computation); + +// Naive Post Order scheduler +StatusOr PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const absl::flat_hash_map& + memory_by_computation); + +// The default scheduling algorithm. Runs both the list scheduler +// and the DFS scheduler, and chooses whichever returns a lower min-memory, +// not accounting for fragmentation. +StatusOr DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const absl::flat_hash_map& + memory_by_computation); + +// Returns an HloSchedule which seeks to minimize the memory required for +// the computation. size_function is the function returning the number of bytes +// required for a LogicalBuffer. +StatusOr ScheduleModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + +// Computes the schedule for a single computation. +// Currently only used by the GPU backend. +StatusOr ScheduleComputation( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function); + +// A pass which schedules the HLO instructions in a module. 
The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloModulePass { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + StatusOr Run(HloModule* module) override; + + private: + LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModule::has_schedule will return false. +class HloDescheduler : public HloModulePass { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..214119fba881c4411a262cd4227b5cc49cef0d14 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -0,0 +1,313 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloSchedulingTest : public HloTestBase {}; + +TEST_F(HloSchedulingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. 
+ const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + HloMemoryScheduler scheduler([](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + + // Verify that all instructions are in the sequence. + const std::vector& sequence = + module->schedule().sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); + + SequentialHloOrdering ordering(module->schedule()); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); + + // Clear the schedule using the descheduling pass. 
+ HloDescheduler descheduler; + EXPECT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed, + descheduler.Run(module.get())); + EXPECT_TRUE(descheduler_changed); + EXPECT_FALSE(module->has_schedule()); +} + +TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { + const char* module_str = R"( +HloModule test_aliasing_module + +ENTRY root { + param = s32[1000] parameter(0) + p0 = s32[1000] copy(param) + p1 = s32[1000] copy(param) + t = (s32[1000], s32[1000]) tuple(p0, p1) + a = s32[1000] get-tuple-element(t), index=0 + b = s32[1000] get-tuple-element(t), index=1 + c = s32[1000] add(a, b) + d = s32[1000] add(c, b) + e = s32[1000] add(c, c) + f = s32[1000] add(e, e) + ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : sequence) { + instructions_by_name[instruction->name()] = instruction; + } + + // The first instruction should be the parameter and the last the root. + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); + + // Instructions "d" and "e" will both be schedulable at the same time, but + // instruction "d" allows us to free the buffer of "p1", so the list scheduler + // should prefer it. 
+ SequentialHloOrdering ordering(schedule); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), + instructions_by_name.at("e"))); +} + +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. + auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + absl::Span({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); + // tuple allocates the tuple buffer and doesn't free anything. + // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. 
+ EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. 
+ EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); +} + +TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the 
sequence. + auto entry_computation = module->entry_computation(); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); + + absl::flat_hash_map memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. Cond is the largest one. + // The output buffer of the while is aliased. + EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 78167335c8efeb3de4b475bba562a8f0150a3aa6..6845c27a91845ef971dc2d82266200bfccb25533 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,9 +23,12 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -50,9 +53,16 @@ StatusOr HloModule::LaunderConstInstructionFromModule( return const_cast(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -63,32 +73,40 @@ HloComputation* HloModule::AddComputationInternal( config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } + input_output_alias_config_ = HloInputOutputAliasConfig( + entry_computation_->root_instruction()->shape()); } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. 
+ CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. - for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. 
- CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -97,7 +115,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -114,7 +132,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -130,7 +148,8 @@ void HloModule::ReplaceComputations( case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: { + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: { HloComputation* new_arg = tensorflow::gtl::FindWithDefault( replacements, instruction->to_apply(), nullptr); if (new_arg != nullptr) { @@ -198,12 +217,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + 
options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -216,22 +246,28 @@ HloModuleProto HloModule::ToProto() const { proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { HloComputationProto computation_proto = computation->ToProto(); - if (computation->name() == entry_computation_->name()) { - *proto.mutable_program_shape() = computation_proto.program_shape(); - } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } + *proto.mutable_host_program_shape() = + entry_computation_layout().ComputeProgramShape(); + *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); return proto; } /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. 
- TF_RET_CHECK(proto.has_program_shape()) + TF_RET_CHECK(proto.has_host_program_shape()) << "No program shape found in the proto"; - const auto& expected_program_shape = proto.program_shape(); + const auto& expected_program_shape = proto.host_program_shape(); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { @@ -254,8 +290,8 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - tensorflow::gtl::FlatMap computation_map; - tensorflow::gtl::FlatMap to_proto_id; + absl::flat_hash_map computation_map; + absl::flat_hash_map to_proto_id; std::vector> computations; HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { @@ -290,34 +326,55 @@ StatusOr> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. - tensorflow::gtl::FlatSet computation_names; - tensorflow::gtl::FlatSet instruction_names; + TF_ASSIGN_OR_RETURN(module->input_output_alias_config_, + HloInputOutputAliasConfig::CreateFromProto( + result_shape, proto.input_output_alias())); + + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. 
+ absl::flat_hash_set computation_names; + absl::flat_hash_set instruction_names; + absl::flat_hash_set computation_ids; + absl::flat_hash_set instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } /* static */ StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options) { - TF_RET_CHECK(module.has_program_shape()) + TF_RET_CHECK(module.has_host_program_shape()) << "No program shape found in the proto"; - const auto& program_shape = module.program_shape(); + const auto& program_shape = module.host_program_shape(); HloModuleConfig module_config(program_shape); module_config.set_debug_options(debug_options); @@ -353,7 +410,7 @@ bool IsUsedOutsideSubcomputation( } // anonymous namespace HloInstruction* HloModule::OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& 
outlined_computation_name, HloComputation* computation) { auto builder = HloComputation::Builder(outlined_computation_name); @@ -507,8 +564,13 @@ std::vector HloModule::MakeNonfusionComputations() const { } std::unique_ptr HloModule::Clone(const string& suffix) const { + return Clone(config(), suffix); +} + +std::unique_ptr HloModule::Clone(const HloModuleConfig& config, + const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = absl::make_unique(name_ + "-" + suffix, config_); + auto module = absl::make_unique(name_ + "-" + suffix, config); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index cf129b835db56c21245c7e98d7e7876c1e507132..5dc795fabec5d8d794635ef6965c4d065b0b75a6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,15 +25,18 @@ limitations under the License. 
#include #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -61,6 +64,7 @@ class HloModule { // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. explicit HloModule(const string& name, const HloModuleConfig& config); + virtual ~HloModule() {} // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -85,9 +89,12 @@ class HloModule { const std::unordered_map& replacements); const string& name() const { return name_; } + void set_name(string name) { name_ = std::move(name); } // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; + std::unique_ptr Clone(const HloModuleConfig& config, + const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all // the called computations as well. 
If the clone context is specified, it @@ -95,7 +102,7 @@ class HloModule { HloComputation* DeepCloneComputation(HloComputation* computation, HloCloneContext* context = nullptr); - // Return a pointer to the entry computation of the module.. + // Return a pointer to the entry computation of the module. const HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; @@ -105,6 +112,14 @@ class HloModule { return entry_computation_; } + // Returns the root instruction shape of entry computation. + // + // Precondition: entry_computation_ is not nullptr. + const Shape& result_shape() const { + CHECK_NE(nullptr, entry_computation_); + return entry_computation()->root_instruction()->shape(); + } + // Creates the ComputationLayout which describes the current status of the HLO // module entry computation. ComputationLayout compute_computation_layout() const { @@ -192,7 +207,7 @@ class HloModule { // order (root of outlined instructions last). TODO(jingyue): takes a set of // instructions and topologically sorts them. HloInstruction* OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation); // Returns a randomly generated uint64. @@ -208,9 +223,14 @@ class HloModule { return result; } - // Returns the number of unique intruction ids given out. All ids up to - // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) - int NumUniqueInstructionIds() const { return next_unique_id_; } + // input_output_alias_config indicates the list of aliased buffers that are + // expected from the module. 
+ HloInputOutputAliasConfig& input_output_alias_config() { + return input_output_alias_config_; + } + const HloInputOutputAliasConfig& input_output_alias_config() const { + return input_output_alias_config_; + } // Returns an id that is unique to this module across all modules created over // the lifetime of this process. @@ -235,12 +255,25 @@ class HloModule { StatusOr LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedule of the module. CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names); + bool uniquify_identifiers); - const string name_; + string name_; HloModuleConfig config_; HloComputation* entry_computation_ = nullptr; std::vector> computations_; @@ -262,6 +295,15 @@ class HloModule { static std::atomic next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. + absl::optional schedule_; + + // alias_config indicates the alias information of input/output buffers that + // are expected from the module. 
+ HloInputOutputAliasConfig input_output_alias_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 3f1e1cc73eeb9debe5eb6278ab192fdf9b8cc10f..68c18836eb01484b819e7b7bd26f099dcf56e7ba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -106,9 +106,6 @@ class HloModuleConfig { absl::optional entry_computation_layout_; - // Whether this is a 'host module'. - bool is_host_module_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 98d20315e399c6b1a3979b5d11a89ef93869f4d9..31d26cc51e8217234526bbfeb83510aadf2c27b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -36,23 +36,6 @@ namespace xla { namespace { -bool HasSendRecv(HloComputation* computation) { - for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kSendDone || - instruction->opcode() == HloOpcode::kRecv || - instruction->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (auto* sub_computation : instruction->called_computations()) { - if (HasSendRecv(sub_computation)) { - return true; - } - } - } - return false; -} - StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { bool changed = false; for (auto* computation : module->computations()) { @@ -67,10 +50,9 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_root = while_body_comp->root_instruction(); if (!ShapeUtil::IsTuple(xla_while->shape()) || - while_body_root->opcode() != HloOpcode::kTuple || - HasSendRecv(while_body_comp)) { + while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops 
where body root is Tuple, - // with no send/recv instructions. + // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); continue; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 12ca2340a6ccaa50780e81168c755c1fec3aa1be..d472211d2af6e4b583d3815146ba8cee5c8e7495 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -28,7 +28,7 @@ namespace xla { // Sweeps through live instructions which cross computation boundaries (kWhile), // and removes code at dead shape indices. // -class HloModuleDCE : public HloPassInterface { +class HloModuleDCE : public HloModulePass { public: ~HloModuleDCE() override {} absl::string_view name() const override { return "hlo-module-dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 363862e4905fc13a4ef07aeaac255259fc6b86ba..bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { "while.2", 1)); } +// Tests that a while whose body has outfeed operations is not DCE-ed. 
+TEST_F(HloModuleDceTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + +// Tests that if a loop variable is not referenced outside of a kWhile, the loop +// variable changes are not elided within the loop body, if the condition +// computation uses them. 
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { + auto module = ParseHloString(R"( + HloModule InfiniteLoop + WhileBody { + body_param = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2) + } + WhileCondition { + cond_param = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + p0 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(p0), index=0 + constant.3 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5) + while = (s32[], s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc new file mode 100644 index 0000000000000000000000000000000000000000..8999ac9f324ed24cf34ef6826000e1fa4f741e19 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +namespace xla { + +HloModuleGroup::HloModuleGroup(std::unique_ptr module) + : name_(module->name()) { + push_back(std::move(module)); +} + +HloModuleGroup::HloModuleGroup(absl::string_view name, + absl::Span> modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + +std::vector> HloModuleGroup::ConsumeModules() { + std::vector> ret_modules = std::move(modules_); + + // Clear everything so the object state is in a known (empty) state. 
+ modules_.clear(); + module_ptrs_.clear(); + return ret_modules; +} + +string HloModuleGroup::ToString() const { + std::ostringstream s; + s << "HloModuleGroup " << name() << "\n\n"; + for (const HloModule* module : modules()) { + s << module->ToString() << "\n"; + } + return s.str(); +} + +HloModuleGroupProto HloModuleGroup::ToProto() const { + HloModuleGroupProto proto; + proto.set_name(name()); + for (const HloModule* module : modules()) { + *proto.add_hlo_modules() = module->ToProto(); + } + return proto; +} + +/* static */ StatusOr HloModuleGroup::CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs) { + TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; + TF_RET_CHECK(proto.hlo_modules_size() > 0) + << "Module group must have at least one HLO module"; + TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size()); + + std::vector> modules; + for (int i = 0; i < proto.hlo_modules_size(); ++i) { + const HloModuleProto& module_proto = proto.hlo_modules(i); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(module_proto, module_configs[i])); + modules.push_back(std::move(module)); + } + + return HloModuleGroup(proto.name(), absl::MakeSpan(modules)); +} + +void HloModuleGroup::push_back(std::unique_ptr module) { + modules_.push_back(std::move(module)); + module_ptrs_.push_back(modules_.back().get()); +} + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { + out << group.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h new file mode 100644 index 0000000000000000000000000000000000000000..7c39cf17815aa08742e6d5b35941d8043531d034 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// An abstraction representing an ordered set of HLO modules built to run +// concurrently across different devices. +class HloModuleGroup { + public: + // Construct an empty module group. + explicit HloModuleGroup(absl::string_view name) : name_(name) {} + + // Construct a module group containing a single module. + explicit HloModuleGroup(std::unique_ptr module); + + // Construct a module group containing any number of modules. + HloModuleGroup(absl::string_view name, + absl::Span> modules); + + // Returns the modules contained in the group. + const std::vector& modules() const { return module_ptrs_; } + + // Returns a module at a particular index. + HloModule& module(int index) const { return *module_ptrs_.at(index); } + + // Add a module to the back of the vector of modules in the group. + void push_back(std::unique_ptr module); + + // Replaces the existing module at the given index with the given module. The + // existing module is discarded. 
+ void ReplaceModule(int index, std::unique_ptr module); + + // Moves all modules from the group into the returned vector. After this + // method runs, the module group will be empty. + std::vector> ConsumeModules(); + + string name() const { return name_; } + + string ToString() const; + + // Serialize the module group to/from a proto. + HloModuleGroupProto ToProto() const; + static StatusOr CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs); + + // Returns the number of modules in the module group. + int size() const { return modules_.size(); } + + // Returns true if there are no modules in the module group. + bool empty() const { return modules_.empty(); } + + private: + string name_; + + // Vector of modules as std::unique_ptrs. + std::vector> modules_; + + // Vector of modules as normal pointers. This vector is kept in sync with + // modules_ as modules are added to the group with push_back. + std::vector module_ptrs_; +}; + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index f52a37bc7426ea6f1cf8754d9ee8db98b1493f15..b4aac4c8076cb69647d42c6243bc969d06d0709e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -58,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { } /* static */ StatusOr> -HloModuleGroupMetadata::Build(const std::vector& modules) { +HloModuleGroupMetadata::Build(absl::Span modules) { auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); @@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() { if (VLOG_IS_ON(4)) { DumpCollectedStats(); } + + for (HloModule* module : modules_) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + points_to_analyses_[module] = std::move(points_to_analysis); + } + return Status::OK(); } @@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { ss << " " << hlo->name() << std::endl; } ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); + return FailedPrecondition("%s", ss.str()); } } } @@ -383,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique>()); + absl::make_unique>()); auto companion_set = companion_sets_.back().get(); - companion_set->insert(instruction1); - companion_set->insert(instruction2); + companion_set->push_back(instruction1); + companion_set->push_back(instruction2); companion_set_index_[instruction1] = companion_sets_.size() - 1; companion_set_index_[instruction2] = companion_sets_.size() - 1; } else if (!ContainsKey(companion_set_index_, 
instruction1)) { - companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_sets_[companion_set_index_[instruction2]]->push_back( + instruction1); companion_set_index_[instruction1] = companion_set_index_[instruction2]; } else if (!ContainsKey(companion_set_index_, instruction2)) { - companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_sets_[companion_set_index_[instruction1]]->push_back( + instruction2); companion_set_index_[instruction2] = companion_set_index_[instruction1]; } else if (companion_set_index_[instruction1] != companion_set_index_[instruction2]) { - companion_sets_[companion_set_index_[instruction1]]->insert( - Companions(instruction2).begin(), Companions(instruction2).end()); + // At any point while building the companion sets, each instruction belongs + // to at most 1 companion set, so the union of two companion sets is + // concatenating two disjoint sets. + absl::c_copy(Companions(instruction2), + std::back_inserter( + *companion_sets_[companion_set_index_[instruction1]])); int64 index_to_remove = companion_set_index_[instruction2]; for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; @@ -411,16 +426,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, Status HloModuleGroupMetadata::VerifyChannelInstructions() { for (const Channel& channel : channels_) { if (channel.send == nullptr) { - return FailedPrecondition("missing send for id : %lld", channel.id); + return FailedPrecondition("missing send for id : %d", channel.id); } if (channel.recv == nullptr) { - return FailedPrecondition("missing recv for id : %lld", channel.id); + return FailedPrecondition("missing recv for id : %d", channel.id); } if (channel.send_done == nullptr) { - return FailedPrecondition("missing send-done for id : %lld", channel.id); + return FailedPrecondition("missing send-done for id : %d", channel.id); } if 
(channel.recv_done == nullptr) { - return FailedPrecondition("missing recv-done for id : %lld", channel.id); + return FailedPrecondition("missing recv-done for id : %d", channel.id); } } @@ -436,33 +451,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { auto send_done_device = GetInstructionDevice(*channel.send_done); if (!send_device) { return FailedPrecondition("send instruction must have a device: %s", - channel.send->ToString().c_str()); + channel.send->ToString()); } if (!send_done_device) { return FailedPrecondition("send_done instruction must have a device: %s", - channel.send_done->ToString().c_str()); + channel.send_done->ToString()); } if (*send_device != *send_done_device) { return FailedPrecondition( - "send and send-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "send and send-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *send_device, *send_done_device); } auto recv_device = GetInstructionDevice(*channel.recv); auto recv_done_device = GetInstructionDevice(*channel.recv_done); if (!recv_done_device) { return FailedPrecondition("recv_done instruction must have a device: %s", - channel.recv_done->ToString().c_str()); + channel.recv_done->ToString()); } if (*recv_device != *recv_done_device) { return FailedPrecondition( - "recv and recv-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "recv and recv-done (channel=%d) must be on the same device: %d " + "vs. 
%d", channel.id, *recv_device, *recv_done_device); } if (*send_device == *recv_device) { return FailedPrecondition( - "send and recv (channel=%lld) must be on different devices: %lld", + "send and recv (channel=%d) must be on different devices: %d", channel.id, *send_device); } } @@ -483,7 +498,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { !CheckCompanionPathsCompatibility( path, GetCompanionsPath(channel.recv_done))) { return FailedPrecondition( - "Nest companion paths do not match for channel %lld", channel.id); + "Nest companion paths do not match for channel %d", channel.id); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index dead6d9c2090c2f296788bbb97dbd7edc4ce4392..928df0f5a7444ad877961a5de970c752e1d024da 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,14 +22,15 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -101,14 +102,14 @@ class HloModuleGroupMetadata { HloInstruction* recv_done = nullptr; }; - explicit HloModuleGroupMetadata(const std::vector& modules) - : modules_(modules) {} + explicit HloModuleGroupMetadata(absl::Span modules) + : modules_(modules.begin(), modules.end()) {} ~HloModuleGroupMetadata() = default; // Build and return the metadata for the given modules. static StatusOr> Build( - const std::vector& modules); + absl::Span modules); // Returns true if the instruction is one of the 4 channel instructions (Send, // Recv, SendDone, RecvDone). @@ -168,14 +169,14 @@ class HloModuleGroupMetadata { // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. - const std::unordered_set& Companions( + const std::vector& Companions( const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } // Returns the companion set at the given index. - const std::unordered_set& companion_set(int64 index) const { + const std::vector& companion_set(int64 index) const { CHECK_LT(index, companion_sets_.size()); return *companion_sets_[index]; } @@ -186,7 +187,7 @@ class HloModuleGroupMetadata { } // Returns the list of all companion sets in the HLO module group. 
- const std::vector>>& + const std::vector>>& companion_sets() const { return companion_sets_; } @@ -197,6 +198,10 @@ class HloModuleGroupMetadata { // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } + TuplePointsToAnalysis* points_to_analysis(HloModule* module) const { + return points_to_analyses_.at(module).get(); + } + private: Status Build(); @@ -242,35 +247,37 @@ class HloModuleGroupMetadata { void DumpCollectedStats() const; // List of all companion instructions sets in the module. - std::vector>> - companion_sets_; + std::vector>> companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + absl::flat_hash_map companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). - tensorflow::gtl::FlatMap + absl::flat_hash_map tracked_instructions_; // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of // communicating instructions within the proper called computation(s). - tensorflow::gtl::FlatMap> + absl::flat_hash_map> tracked_instructions_comms_; // All channels in the module. std::vector channels_; // Map from channel ids to the index in channels_. - tensorflow::gtl::FlatMap channel_id_map_; + absl::flat_hash_map channel_id_map_; // Map from all-reduce ids to the all reduce instructions. - tensorflow::gtl::FlatMap> all_reduce_map_; + absl::flat_hash_map> all_reduce_map_; // The maximum channel id used in the module group. int64 max_channel_id_ = -1; // The modules that this metadata was built from. 
- const std::vector& modules_; + const std::vector modules_; + + absl::flat_hash_map> + points_to_analyses_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a9a86af5649bf240bb5de6d30fc80b0f6a58eba --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class HloModuleGroupTest : public HloTestBase { + protected: + HloModuleGroupTest() = default; +}; + +TEST_F(HloModuleGroupTest, SingleModule) { + const string text = R"( +HloModule simple_module + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + HloModuleGroup group(std::move(module)); + + EXPECT_EQ(group.modules().size(), 1); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config()})); + EXPECT_EQ(group_copy.modules().size(), 1); + EXPECT_THAT( + group_copy.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + std::vector> modules = group.ConsumeModules(); + EXPECT_EQ(modules.size(), 1); + EXPECT_EQ(group.modules().size(), 0); +} + +TEST_F(HloModuleGroupTest, MultipleModules) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( 
+HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + std::vector> modules; + modules.push_back(std::move(module_0)); + modules.push_back(std::move(module_1)); + HloModuleGroup group(TestName(), absl::MakeSpan(modules)); + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config(), + group.module(1).config()})); + EXPECT_EQ(group_copy.modules().size(), 2); +} + +TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + HloModuleGroup group(TestName()); + group.push_back(std::move(module_0)); + group.push_back(std::move(module_1)); + + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); +} + +// Tests that the order of companion instructions in the companion set doesn't +// change across runs. 
+TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) { + // A simple while loop template for core i sending to core i+1. + constexpr char text[] = R"( +HloModule module_%d + +while_cond { + ROOT p = pred[] constant(true) +} + +while_body { + param = s32[] parameter(0) + token.s = token[] after-all() + token.r = token[] after-all() + send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d + send-done = token[] send-done(send), channel_id=%d + recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d + ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d +} + +ENTRY entry { + while_init = s32[] constant(1) + ROOT while = s32[] while(while_init), condition=while_cond, body=while_body +} +)"; + + // Try creating the module and the metadata kTrialCount times and check the + // companion instructions remain in the same order. + const int64 kTrialCount = 5; + const int64 kDeviceCount = 10; + std::vector companion_order; + + for (int64 t = 0; t < kTrialCount; ++t) { + HloModuleGroup group(TestName()); + for (int64 i = 0; i < kDeviceCount; ++i) { + const int64 send_channel = i; + const int64 recv_channel = i == 0 ? 
kDeviceCount - 1 : i - 1; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(absl::StrFormat(text, i, send_channel, send_channel, + recv_channel, recv_channel))); + group.push_back(std::move(module)); + } + ASSERT_EQ(group.modules().size(), kDeviceCount); + + TF_ASSERT_OK_AND_ASSIGN(auto metadata, + HloModuleGroupMetadata::Build(group.modules())); + ASSERT_EQ(metadata->companion_sets().size(), 1); + + std::vector module_ids; + for (HloInstruction* companion : *metadata->companion_sets()[0]) { + module_ids.push_back(metadata->GetModuleId(companion->GetModule())); + } + + if (t == 0) { + companion_order = module_ids; + } else { + EXPECT_TRUE(absl::c_equal(companion_order, module_ids)); + } + } +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index b5c7681edd8eff202b79d1c88afc419b1f6a9f3f..fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { std::vector predecessors; // Use a vector to avoid non-determinism. 
- tensorflow::gtl::FlatSet unique; + absl::flat_hash_set unique; // Adds to the unique predecessors list; if the predecessors is a companion // instruction, also add companion instructions; if the predecessors is a @@ -119,7 +119,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { std::vector successors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet unique; + absl::flat_hash_set unique; // Adds to the unique successors list; if the successor is a companion // instruction, also add companion instructions; if the successor is a @@ -193,7 +193,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( } std::vector HloModuleGroupUtil::RootInstructions( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { std::vector roots; for (HloComputation* computation : computations) { for (HloInstruction* instruction : computation->instructions()) { @@ -282,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( "following nodes. 
Note that the order of the nodes is arbitrary " "and that the list may include nodes that are not part of the " "cycle.\n%s", - predecessor->ToString().c_str(), cyclic_instructions.c_str()); + predecessor->ToString(), cyclic_instructions); } stack.push(predecessor); } @@ -293,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( } Status HloModuleGroupUtil::VerifyComputations( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { @@ -324,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { std::vector post_order; auto visit_function = [&](HloInstruction* instruction, diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index c25ca1aff50b288f3ac3885cbed53e7ba9768430..f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" @@ -27,8 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -56,7 +56,7 @@ class HloModuleGroupUtil { // Returns the root instructions of the computations. 
std::vector RootInstructions( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Visit state of each instruction during DFS traversal. enum VisitState { @@ -87,21 +87,20 @@ class HloModuleGroupUtil { // * visit_state: map from each instruction to its visit state. // * visit_function: function called when each instruction group. // * root: the root instruction of the traversal. - using VisitStates = tensorflow::gtl::FlatMap; + using VisitStates = absl::flat_hash_map; Status VisitTopologicalOrder(VisitStates* visit_state, const VisitFunction& visit_function, HloInstruction* root); // Verifies that the computations are well-formed (e.g., no cycles). - Status VerifyComputations( - tensorflow::gtl::ArraySlice computations); + Status VerifyComputations(absl::Span computations); // Below Reachability utils resemble those in HloComputation, except that // they can handle instructions across multiple computations. // // Creates the reachability map for the instructions in the computations. StatusOr> ComputeReachability( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Updates the reachability of the given instruction, taking the global // predeccessorss and successors into account. diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 209ad5e58c9360fafc3d63606e61a553de73be13..39f38b417ab0e8b54864176d8d1e0ad1a422eca6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -19,17 +19,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" - +#include "tensorflow/core/lib/core/status_test_util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -44,7 +49,7 @@ class HloModuleTest : public HloTestBase { // Creates a computation which calls the given zero-parameter computations. std::unique_ptr CreateCallComputation( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { auto builder = HloComputation::Builder("Call"); for (auto computation : computations) { builder.AddInstruction( @@ -194,6 +199,153 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr 
module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + +TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + +TEST_F(HloModuleTest, ProtoSerializationPreservesIds) { + // Verify that serializing then deserializing an HLO proto preserves the + // unique IDs of the instructions and computations. 
+ const string text = + R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + + // Perform various transformations on the graph: + // + // * clone the reduction function + // * replace use of reduction function with the clone. + // * add a random instruction to the entry computation. + // + // This will create instruction and computation IDs which are interesting: + // not consecutive and not densely packed. + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + HloComputation* reduction = root->to_apply(); + HloComputation* reduction_clone = + module->AddEmbeddedComputation(reduction->Clone()); + root->set_to_apply(reduction_clone); + TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction)); + HloInstruction* negate = entry->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root)); + entry->set_root_instruction(negate); + + // Schedule the transformed module, this verifies that the serialized schedule + // is robust against non-consecutive IDs as well (b/114712358). + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + HloMemoryScheduler scheduler(size_fn); + TF_ASSERT_OK(scheduler.Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + + // Serialize and deserialize and verify that the instruction and computations + // unique ids are the same. 
+ TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + + // The module IDs should *not* be the same because module ids must be globally + // unique. + EXPECT_NE(module->unique_id(), module_copy->unique_id()); + + // Verify that the computations and instructions all have the same unique id. + auto computation_copy_it = module_copy->computations().begin(); + for (const HloComputation* computation_orig : module->computations()) { + const HloComputation* computation_copy = *computation_copy_it++; + EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id()) + << absl::StrFormat( + "ID of original computation %s != ID of deserialized " + "computation %s: %d != %d", + computation_orig->name(), computation_copy->name(), + computation_orig->unique_id(), computation_copy->unique_id()); + + auto instruction_copy_it = computation_copy->instructions().begin(); + for (const HloInstruction* instruction_orig : + computation_orig->instructions()) { + const HloInstruction* instruction_copy = *instruction_copy_it++; + EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id()) + << absl::StrFormat( + "ID of original instruction %s != ID of deserialized " + "instruction %s: %d != %d", + instruction_orig->name(), instruction_copy->name(), + instruction_orig->unique_id(), instruction_copy->unique_id()); + } + } + + // Verify that the next unique ID which the module would have handed out is + // greater than the unique id of any instruction. 
+ int next_id = module_copy->NewUniqueInstructionId(); + for (const HloComputation* computation : module_copy->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_GT(next_id, instruction->unique_id()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d1eaf357855205f1e9867e86f3042b96b6beff97..4551a1c2e259b06818f913cb6a9e782436b7e594 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) { } StatusOr StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap({ + static auto* opcode_map = new absl::flat_hash_map({ #define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) 
\ {opcode_name, HloOpcode::enum_name}, HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) @@ -39,7 +39,7 @@ StatusOr StringToHloOpcode(const string& opcode_name) { }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { - return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + return InvalidArgument("Unknown opcode: %s", opcode_name); } return it->second; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index b8f2a21ff9df6460303610cf64c98d1b96836171..e6bfb8025d4bfeba1d334d1f946e33841a2da092 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kCollectivePermute, "collective-permute") \ V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 8fe91c7278dba073a02283575f80780f23d1be83..23d41d91d6969ddf9062507e926ae39c1e1315d4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -26,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -91,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { - // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' - // is live into the module. + // Entry parameter should always be defined before other instructions. const HloModule* module = b.defining_instruction()->parent()->parent(); if (b.defining_instruction()->parent() == module->entry_computation() && b.defining_instruction()->opcode() == HloOpcode::kParameter) { return false; } + if (a.defining_instruction()->parent() == module->entry_computation() && + a.defining_instruction()->opcode() == HloOpcode::kParameter) { + return true; + } + // Phi values require special handling. Because XLA does not have a phi // instruction, the definition instruction of the phis values are // placeholders: either the subcomputation parameter (body or condition) or @@ -252,6 +257,12 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. 
for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -264,6 +275,18 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -274,23 +297,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} @@ -306,17 +312,15 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); for (auto* computation : module_->MakeNonfusionComputations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s:", - computation->name().c_str())); + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { - pieces.push_back(tensorflow::strings::Printf( - " %s predecessors:", instruction->name().c_str())); + pieces.push_back( + absl::StrFormat(" %s predecessors:", instruction->name())); for (auto 
predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back( - tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } @@ -338,15 +342,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. - for (auto computation_order : module_sequence_) { - const std::vector& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -364,50 +377,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? 
&schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s order:", - computation->name().c_str())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. - std::vector instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back( - tensorflow::strings::Printf(" %s", instruction->name().c_str())); - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a..66313492eb2dd10ac9a6000639ddb8991b367c0f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,14 +20,15 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -71,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. @@ -123,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering { // predecessors. An instruction is an element of its own predecessor set. // // Subclasses should fill this in to define the desired ordering. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> predecessors_; }; @@ -183,17 +180,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. 
- using HloModuleSequence = - tensorflow::gtl::FlatMap>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. @@ -203,10 +191,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -214,13 +204,9 @@ class SequentialHloOrdering : public HloOrdering { // this map so more than one instruction may have the same position // value. This is not a problem because ExecutesBefore also verifies // instructions are in the same computation. - tensorflow::gtl::FlatMap order_position_; + absl::flat_hash_map order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c70bff1d2a022e395652049768d6d21..b045adc9640ac0ca8cf4a127fea2fbfcbb1aaf3f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -173,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } +TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { + // Entry parameter should always be defined before other instruction. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + DependencyHloOrdering ordering(module.get()); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param), + dataflow->GetValueDefinedAt(constant))); + EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(param))); +} + TEST_F(HloOrderingTest, ValuesInWhileComputations) { // Tests the ordering of values (defined by dataflow analysis) in the body and // condition of a while instruction. 
HLO code: @@ -376,5 +397,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. 
That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla diff 
--git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index df789e6222fe1574fc4a45e6200f69fa95c9a81f..81f091238e5725f64b953f70b82d52cc90aef8ea 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" @@ -25,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -39,24 +40,35 @@ using absl::nullopt; using absl::optional; using absl::StrAppend; using absl::StrCat; +using absl::StrFormat; using absl::StrJoin; -using ::tensorflow::strings::Printf; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. +HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. 
class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(absl::string_view str, const HloModuleConfig& config) - : lexer_(str), config_(config) {} - - // Runs the parser. Returns false if an error occurred. - bool Run(); + explicit HloParser(absl::string_view str) : lexer_(str) {} - // Returns the parsed HloModule. - std::unique_ptr ConsumeHloModule() { return std::move(module_); } + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns false if an error occurred. + Status Run(HloModule* module); // Returns the error information. string GetError() const { return StrJoin(error_, "\n"); } @@ -65,40 +77,47 @@ class HloParser { StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); - - // Stand-alone parsing utility for a single instruction worth of text. - Status ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name); + StatusOr ParsePaddingConfigOnly(); private: - // Locates an instruction with the given name in the instruction_pool_ or + using InstrNameTable = + std::unordered_map>; + + // Returns the map from the instruction name to the instruction itself and its + // location in the current scope. + InstrNameTable& current_name_table() { return scoped_name_tables_.back(); } + + // Locates an instruction with the given name in the current_name_table() or // returns nullptr. // - // If the missing_instruction_hook_ is registered and a "shape" is provided, - // the hook will be called and may satisfy the request for the given - // instruction. This is useful when we reify parameters as they're resolved; - // i.e. for ParseSingleInstruction. + // When the name is not found or name is empty, if create_missing_instruction_ + // hook is registered and a "shape" is provided, the hook will be called to + // create an instruction. This is useful when we reify parameters as they're + // resolved; i.e. for ParseSingleInstruction. 
std::pair* FindInstruction( const string& name, const optional& shape = nullopt); + // Parse a single instruction worth of text. + bool ParseSingleInstruction(HloModule* module); + // ParseXXX returns false if an error occurred. - bool ParseHloModule(); - bool ParseComputations(); + bool ParseHloModule(HloModule* module); + + bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); - bool ParseInstructionList(HloComputation::Builder* builder, - string* root_name); + bool ParseInstructionList(HloComputation** computation, + const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template - bool ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. 
@@ -155,6 +174,7 @@ class HloParser { kDistribution, kDomain, kPrecisionList, + kShapeList }; struct AttrConfig { @@ -220,7 +240,8 @@ class HloParser { bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); - bool ParsePrecisionList(std::vector* result); + bool ParsePrecisionList(std::vector* result); + bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -239,7 +260,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParsePrecision(PrecisionConfigProto::Precision* result); + bool ParsePrecision(PrecisionConfig::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -271,25 +292,47 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction/computation name to the - // instruction/computation itself and it's location. This does not own the - // pointers. - std::unordered_map> - instruction_pool_; + HloLexer lexer_; + + // A stack for the instruction names. The top of the stack stores the + // instruction name table for the current scope. + // + // A instruction's name is unique among its scope (i.e. its parent + // computation), but it's not necessarily unique among all computations in the + // module. When there are multiple levels of nested computations, the same + // name could appear in both an outer computation and an inner computation. So + // we need a stack to make sure a name is only visible within its scope, + std::vector scoped_name_tables_; + + // A helper class which pushes and pops to an InstrNameTable stack via RAII. 
+ class Scope { + public: + explicit Scope(std::vector* scoped_name_tables) + : scoped_name_tables_(scoped_name_tables) { + scoped_name_tables_->emplace_back(); + } + ~Scope() { scoped_name_tables_->pop_back(); } + + private: + std::vector* scoped_name_tables_; + }; + + // Map from the computation name to the computation itself and its location. std::unordered_map> computation_pool_; - HloLexer lexer_; - std::unique_ptr module_; std::vector> computations_; - const HloModuleConfig config_; std::vector error_; - // Function that gets invoked when we try to resolve an instruction - // instruction_pool_ but fail to do so. - std::function*(string, - const optional&)> - missing_instruction_hook_; + // When an operand name cannot be resolved, this function is called to create + // a parameter instruction with the given name and shape. It registers the + // name, instruction, and a placeholder location in the name table. It returns + // the newly-created instruction and the placeholder location. If `name` is + // empty, this should create the parameter with a generated name. This is + // supposed to be set and used only in ParseSingleInstruction. + std::function*(const string& name, + const Shape& shape)> + create_missing_instruction_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { @@ -306,7 +349,7 @@ bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { // Creates replica groups from the provided nested array. groups[i] represents // the replica ids for group 'i'. 
std::vector CreateReplicaGroups( - tensorflow::gtl::ArraySlice> groups) { + absl::Span> groups) { std::vector replica_groups; absl::c_transform(groups, std::back_inserter(replica_groups), [](const std::vector& ids) { @@ -324,7 +367,7 @@ bool HloParser::Error(LocTy loc, absl::string_view msg) { std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(std::string(lexer_.GetLine(loc))); + error_lines.emplace_back(lexer_.GetLine(loc)); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); error_.push_back(StrJoin(error_lines, "\n")); @@ -336,24 +379,50 @@ bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -bool HloParser::Run() { +Status HloParser::Run(HloModule* module) { lexer_.Lex(); - return ParseHloModule(); + if (lexer_.GetKind() == TokKind::kw_HloModule) { + // This means that the text contains a full HLO module. + if (!ParseHloModule(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a HloModule:\n%s", + GetError()); + } + return Status::OK(); + } + // This means that the text is a single HLO instruction. + if (!ParseSingleInstruction(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a single " + "HloInstruction:\n%s", + GetError()); + } + return Status::OK(); } std::pair* HloParser::FindInstruction( const string& name, const optional& shape) { - std::pair* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair* instr = nullptr; + if (!name.empty()) { + instr = tensorflow::gtl::FindOrNull(current_name_table(), name); + } + // Potentially call the missing instruction hook. 
- if (instr == nullptr && missing_instruction_hook_ != nullptr) { - return missing_instruction_hook_(name, shape); + if (instr == nullptr && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + "Operand had no shape in HLO text; cannot create parameter for " + "single-instruction module."); + return nullptr; + } + return create_missing_instruction_(name, *shape); } return instr; } // ::= 'HloModule' name computations -bool HloParser::ParseHloModule() { +bool HloParser::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { return TokenError("expects HloModule"); } @@ -365,13 +434,27 @@ bool HloParser::ParseHloModule() { return false; } - module_ = absl::make_unique(name, config_); + absl::optional is_scheduled; + std::unordered_map attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + + module->set_name(name); + if (!ParseComputations(module)) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module))); + } - return ParseComputations(); + return true; } // computations ::= (computation)+ -bool HloParser::ParseComputations() { +bool HloParser::ParseComputations(HloModule* module) { HloComputation* entry_computation = nullptr; do { if (!ParseComputation(&entry_computation)) { @@ -387,21 +470,20 @@ bool HloParser::ParseComputations() { if ((entry_computation != nullptr && computations_[i].get() != entry_computation) || (entry_computation == nullptr && i != computations_.size() - 1)) { - module_->AddEmbeddedComputation(std::move(computations_[i])); + module->AddEmbeddedComputation(std::move(computations_[i])); continue; } - auto computation = - module_->AddEntryComputation(std::move(computations_[i])); + auto computation = module->AddEntryComputation(std::move(computations_[i])); // The 
parameters and result layouts were set to default layout. Here we // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } @@ -418,7 +500,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -426,40 +507,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - string root_name; - if (!ParseInstructionList(builder.get(), &root_name)) { + HloComputation* computation = nullptr; + if (!ParseInstructionList(&computation, name)) { return false; } - std::pair* root_node = FindInstruction(root_name); - // This means some instruction was marked as ROOT but we didn't find it in the - // pool, which should not happen. - if (!root_name.empty() && root_node == nullptr) { - LOG(FATAL) << "instruction " << root_name - << " was marked as ROOT but the parser has not seen it before"; - } - - HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; - // Now root can be either an existing instruction or a nullptr. If it's a - // nullptr, the implementation of Builder will set the last instruction as - // root instruction. 
- computations_.emplace_back(builder->Build(root)); - HloComputation* computation = computations_.back().get(); - - if (!root) { - root = computation->root_instruction(); - } else { - CHECK_EQ(root, computation->root_instruction()); - } - // If param_list_to_shape was present, check compatibility. - if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + if (shape_loc != nullptr && + !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) { return Error( shape_loc, - StrCat("Shape of computation ", name, ", ", - ShapeUtil::HumanString(shape), - ", is not compatible with that of its root instruction ", - root_name, ", ", ShapeUtil::HumanString(root->shape()))); + StrCat( + "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + computation->root_instruction()->name(), ", ", + ShapeUtil::HumanString(computation->root_instruction()->shape()))); } if (is_entry_computation) { @@ -468,43 +530,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } - instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder, - string* root_name) { +bool HloParser::ParseInstructionList(HloComputation** computation, + const string& computation_name) { + Scope scope(&scoped_name_tables_); + HloComputation::Builder builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } + string root_name; do { - if (!ParseInstruction(builder, root_name)) { + if (!ParseInstruction(&builder, &root_name)) { return false; } } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); + if 
(!ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list.")) { + return false; + } + HloInstruction* root = nullptr; + if (!root_name.empty()) { + std::pair* root_node = + tensorflow::gtl::FindOrNull(current_name_table(), root_name); + + // This means some instruction was marked as ROOT but we didn't find it in + // the pool, which should not happen. + if (root_node == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + root = root_node->first; + } + + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // the root instruction. + computations_.emplace_back(builder.Build(root)); + *computation = computations_.back().get(); + return true; } // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; - Shape shape; - HloOpcode opcode; - std::vector operands; - LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { + !ParseToken(TokKind::kEqual, "expects '=' in instruction")) { return false; } @@ -515,6 +596,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } + return ParseInstruciontRhs(builder, name, name_loc); +} + +bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, + const string& name, LocTy name_loc) { + Shape shape; + HloOpcode opcode; + std::vector operands; + + if (!ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + // Add optional attributes. 
std::unordered_map attrs; optional sharding; @@ -529,10 +623,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; - optional> operand_precision; - attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, - &operand_precision}; - HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -549,7 +639,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -562,11 +652,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { + optional iota_dimension; + attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, + &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + instruction = builder->AddInstruction( + HloInstruction::CreateIota(shape, *iota_dimension)); break; } // Unary ops. 
@@ -702,6 +796,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateAllToAll(shape, operands, replica_groups)); break; } + case HloOpcode::kCollectivePermute: { + optional>> source_targets; + attrs["source_target_pairs"] = { + /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + std::vector> pairs(source_targets->size()); + for (int i = 0; i < pairs.size(); i++) { + if ((*source_targets)[i].size() != 2) { + return TokenError( + "expects 'source_target_pairs=' to be a list of pairs"); + } + pairs[i].first = (*source_targets)[i][0]; + pairs[i].second = (*source_targets)[i][1]; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -724,8 +839,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSort: { - auto loc = lexer_.GetLoc(); - optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; @@ -733,20 +846,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dimensions->size() != 1) { return false; } - switch (operands.size()) { - case 1: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), /*keys=*/operands[0])); - break; - case 2: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], /*values=*/operands[1])); - break; - default: - return Error(loc, StrCat("expects either 1 or 2 operands, but has ", - operands.size(), " operands")); - } + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), + /*keys=*/operands[0], + /*values=*/absl::Span(operands).subspan(1))); break; 
} case HloOpcode::kTuple: { @@ -887,6 +990,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -897,9 +1003,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + PrecisionConfig precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfig::DEFAULT); + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, - feature_group_count.value())); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], + feature_group_count.value(), *window, *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ -972,11 +1086,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateReduce( shape, /*operands=*/ - tensorflow::gtl::ArraySlice(operands, 0, - operands.size() / 2), + absl::Span(operands).subspan( + 0, operands.size() / 2), /*init_values=*/ - tensorflow::gtl::ArraySlice( - operands, operands.size() / 2, operands.size()), + absl::Span(operands).subspan( + operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ -1213,24 +1327,74 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; + optional opaque; optional window; optional dnums; + optional 
feature_group_count; + optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; + attrs["operand_layout_constraints"] = { + /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( - shape, operands, *custom_call_target)); + if (operand_layout_constraints.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return Error(lexer_.GetLoc(), + "Layout must be set on layout-constrained custom call"); + } + if (operands.size() != operand_layout_constraints->size()) { + return Error(lexer_.GetLoc(), + StrCat("Expected ", operands.size(), + " operand layout constraints, ", + operand_layout_constraints->size(), " given")); + } + for (int64 i = 0; i < operands.size(); ++i) { + const Shape& operand_shape_with_layout = + (*operand_layout_constraints)[i]; + if (!LayoutUtil::HasLayout(operand_shape_with_layout)) { + return Error(lexer_.GetLoc(), + StrCat("Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout( + operand_shape_with_layout), + " for operand ", i, " does not have a layout")); + } + if (!ShapeUtil::Compatible(operand_shape_with_layout, + operands[i]->shape())) { + return Error( + lexer_.GetLoc(), + StrCat( + "Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout(operand_shape_with_layout), + " for operand ", i, + " is not compatible with operand shape ", + ShapeUtil::HumanStringWithLayout(operands[i]->shape()))); + } + } + instruction = 
builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, *operand_layout_constraints, + opaque.has_value() ? *opaque : "")); + } else { + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, + opaque.has_value() ? *opaque : "")); + } if (window.has_value()) { instruction->set_window(*window); } if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } + if (feature_group_count.has_value()) { + instruction->set_feature_group_count(*feature_group_count); + } break; } case HloOpcode::kDot: { @@ -1246,6 +1410,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1270,8 +1437,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - instruction = builder->AddInstruction( - HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); + PrecisionConfig precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfig::DEFAULT); + } + + instruction = builder->AddInstruction(HloInstruction::CreateDot( + shape, operands[0], operands[1], dnum, precision_config)); break; } case HloOpcode::kGather: { @@ -1388,12 +1564,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } - if (operand_precision) { - PrecisionConfigProto precision_config; - 
*precision_config.mutable_operand_precision() = {operand_precision->begin(), - operand_precision->end()}; - instruction->set_precision_config(precision_config); - } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1586,8 +1756,7 @@ bool HloParser::ParseInstructionNames( } std::pair* instr = FindInstruction(name); if (!instr) { - return TokenError( - Printf("instruction '%s' is not defined", name.c_str())); + return TokenError(StrFormat("instruction '%s' is not defined", name)); } instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); @@ -1735,8 +1904,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1746,8 +1914,7 @@ bool HloParser::ParseLiteral(std::unique_ptr* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1755,8 +1922,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1782,8 +1948,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr* 
literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1792,8 +1957,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1829,17 +1993,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kLbrace: { nest_level++; if (nest_level > rank) { - return TokenError(Printf( - "expects nested array in rank %lld, but sees larger", rank)); + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees larger", rank)); } if (nest_level > 1) { elems_seen_per_dim[nest_level - 2]++; if (elems_seen_per_dim[nest_level - 2] > shape.dimensions(nest_level - 2)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees more", + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees more", shape.dimensions(nest_level - 2), - get_index_str(nest_level - 2).c_str())); + get_index_str(nest_level - 2))); } } lexer_.Lex(); @@ -1848,9 +2012,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kRbrace: { nest_level--; if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees %lld", - shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees %d", + shape.dimensions(nest_level), get_index_str(nest_level), elems_seen_per_dim[nest_level])); } elems_seen_per_dim[nest_level] = 0; @@ -1871,15 +2035,15 
@@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, if (rank > 0) { if (nest_level != rank) { return TokenError( - Printf("expects nested array in rank %lld, but sees %lld", rank, - nest_level)); + absl::StrFormat("expects nested array in rank %d, but sees %d", + rank, nest_level)); } elems_seen_per_dim[rank - 1]++; if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError( - Printf("expects %lld elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); } } if (lexer_.GetKind() == TokKind::kw_true || @@ -1887,7 +2051,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1898,7 +2062,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1909,7 +2073,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1921,12 +2085,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // end of switch } while (nest_level > 0); - *literal = 
(*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -1966,13 +2129,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, } template -bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = absl::make_unique(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2046,7 +2208,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return false; } - if ((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2054,10 +2216,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } @@ -2086,7 +2248,20 @@ bool HloParser::ParseOperands(std::vector* operands) { } } if (!ParseName(&name)) { - return false; + // When parsing a single instruction (as opposed to a whole module), an + // HLO may have one or more operands with a shape but no name: + // + // foo = add(f32[10], f32[10]) + // + // create_missing_instruction_ is always non-null when parsing a single + // instruction, and is responsible for creating kParameter instructions + // for these operands. 
+ if (shape.has_value() && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + name = ""; + } else { + return false; + } } std::pair* instruction = FindInstruction(name, shape); @@ -2135,8 +2310,8 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("sub-attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("sub-attribute %s is expected but not seen", + attr_it.first)); } } return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); @@ -2156,8 +2331,8 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("attribute %s is expected but not seen", + attr_it.first)); } } return true; @@ -2173,7 +2348,7 @@ bool HloParser::ParseAttributeHelper( } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return Error(loc, Printf("attribute %s already exists", name.c_str())); + return Error(loc, StrFormat("attribute %s already exists", name)); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { @@ -2188,8 +2363,8 @@ bool HloParser::ParseAttributeHelper( StrAppend(out, kv.first); })); } - return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), - allowed_attrs.c_str())); + return Error(loc, StrFormat("unexpected attribute \"%s\". 
%s", name, + allowed_attrs)); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -2239,9 +2414,17 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; + HloComputation* result = nullptr; + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested computation. + if (!ParseInstructionList(&result, /*computation_name=*/"_")) { + return false; + } + } else { + // This means it is a computation name. + if (!ParseComputationName(&result)) { + return false; + } } static_cast*>(attr_out_ptr)->emplace(result); return true; @@ -2372,19 +2555,28 @@ bool HloParser::ParseAttributeHelper( return ParseDomain(static_cast(attr_out_ptr)); } case AttrTy::kPrecisionList: { - std::vector result; + std::vector result; if (!ParsePrecisionList(&result)) { return false; } - static_cast>*>( + static_cast>*>( attr_out_ptr) ->emplace(result); return true; } + case AttrTy::kShapeList: { + std::vector result; + if (!ParseShapeList(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return Error(loc, Printf("error parsing attribute %s", name.c_str())); + return Error(loc, StrFormat("error parsing attribute %s", name)); } return true; } @@ -2548,7 +2740,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_input_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1)); } } } @@ -2571,7 +2763,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_kernel_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1)); } } } @@ -2593,8 +2785,8 @@ bool 
HloParser::ParseConvolutionDimensionNumbers( } else if (c < '0' + rank && c >= '0') { dnums->set_output_spatial_dimensions(c - '0', i); } else { - return TokenError( - Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + return TokenError(StrFormat( + "expects [0-%dbf] in output dimension numbers", rank - 1)); } } } @@ -2640,9 +2832,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { } const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return Error(loc, Printf("expects [start:limit:step] or [start:limit], " - "but sees %ld elements.", - range.size())); + return Error(loc, + StrFormat("expects [start:limit:step] or [start:limit], " + "but sees %d elements.", + range.size())); } } while (EatIfPresent(TokKind::kComma)); @@ -2659,9 +2852,9 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= /*empty*/ // ::= precision_val (delim precision_val)* bool HloParser::ParsePrecisionList( - std::vector* result) { + std::vector* result) { auto parse_and_add_item = [&]() { - PrecisionConfigProto::Precision item; + PrecisionConfig::Precision item; if (!ParsePrecision(&item)) { return false; } @@ -2672,6 +2865,23 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +// shapelist ::= '{' shapes '}' +// precision_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShapeList(std::vector* result) { + auto parse_and_add_item = [&]() { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + result->push_back(std::move(shape)); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2679,23 +2889,15 @@ bool HloParser::ParsePrecisionList( bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - 
TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - // empty - } else { - do { - tensorflow::int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); } bool HloParser::ParseList(const TokKind start, const TokKind end, @@ -2780,7 +2982,8 @@ bool HloParser::ParseShape(Shape* result) { } if (lexer_.GetKind() != TokKind::kShape) { - return TokenError("expects shape"); + return TokenError(absl::StrCat("expected shape, saw ", + TokKindToString(lexer_.GetKind()))); } *result = lexer_.GetShapeVal(); lexer_.Lex(); @@ -2828,14 +3031,13 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return Error(loc, - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { tensorflow::int64 number; if (!ParseInt64(&number)) { - return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } result->push_back(number); return true; @@ -2844,8 +3046,7 @@ bool HloParser::ParseDxD(const string& name, if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); if (!SplitToInt64s(str, 'x', result)) { - return Error(loc, - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); } lexer_.Lex(); return true; @@ -2940,9 +3141,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) { string val = lexer_.GetStrVal(); auto status_or_result = 
StringToHloOpcode(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects opcode but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2956,7 +3156,7 @@ bool HloParser::ParseFftType(FftType* result) { } string val = lexer_.GetStrVal(); if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { - return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + return TokenError(StrFormat("expects fft type but sees: %s", val)); } lexer_.Lex(); return true; @@ -2970,9 +3170,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToFusionKind(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2988,15 +3188,15 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { auto status_or_result = StringToRandomDistribution(val); if (!status_or_result.ok()) { return TokenError( - Printf("expects random distribution but sees: %s, error: %s", - val.c_str(), status_or_result.status().error_message().c_str())); + StrFormat("expects random distribution but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } -bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { +bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { VLOG(1) << "ParsePrecision"; if (lexer_.GetKind() != TokKind::kIdent) { 
return TokenError("expects random distribution"); @@ -3004,9 +3204,9 @@ bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToPrecision(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects precision but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects precision but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -3076,7 +3276,7 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, {instruction, name_loc}}); + auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); return Error(/*loc=*/result.first->second.second, @@ -3100,7 +3300,7 @@ StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); @@ -3112,7 +3312,7 @@ StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after window"); @@ -3125,7 +3325,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; if (!ParseConvolutionDimensionNumbers(&dnums)) { - return 
InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument( @@ -3134,86 +3334,109 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { return dnums; } -Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name) { - TF_RET_CHECK(missing_instruction_hook_ == nullptr); +StatusOr HloParser::ParsePaddingConfigOnly() { + lexer_.Lex(); + PaddingConfig padding_config; + if (!ParsePaddingConfig(&padding_config)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after PaddingConfig"); + } + return padding_config; +} + +bool HloParser::ParseSingleInstruction(HloModule* module) { + if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) { + LOG(FATAL) << "Parser state is not clean. Please do not call any other " + "methods before calling ParseSingleInstruction."; + } + HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on // the fly as a parameter and returns it. int64 parameter_count = 0; - missing_instruction_hook_ = - [this, builder, ¶meter_count]( - string name, - const optional& shape) -> std::pair* { - if (!shape.has_value()) { - Error(lexer_.GetLoc(), - StrCat("Operand ", name, - " had no shape in HLO text; cannot create parameter for " - "single-instruction module.")); - return nullptr; - } - HloInstruction* parameter = builder->AddInstruction( - HloInstruction::CreateParameter(parameter_count++, *shape, name)); - instruction_pool_[name] = {parameter, lexer_.GetLoc()}; - return tensorflow::gtl::FindOrNull(instruction_pool_, name); + create_missing_instruction_ = + [this, &builder, ¶meter_count]( + const string& name, + const Shape& shape) -> std::pair* { + string new_name = name.empty() ? 
StrCat("_", parameter_count) : name; + HloInstruction* parameter = builder.AddInstruction( + HloInstruction::CreateParameter(parameter_count++, shape, new_name)); + current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; - // Prime the lexer. - lexer_.Lex(); - // Parse the instruction with the registered hook. - if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + Scope scope(&scoped_name_tables_); + if (CanBeShape()) { + // This means that the instruction's left-hand side is probably omitted, + // e.g. + // + // f32[10] fusion(...), calls={...} + if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + return false; + } + } else { + // This means that the instruction's left-hand side might exist, e.g. + // + // foo = f32[10] fusion(...), calls={...} + string root_name; + if (!ParseInstruction(&builder, &root_name)) { + return false; + } } - return Status::OK(); + + module->AddEntryComputation(builder.Build()); + for (auto& comp : computations_) { + module->AddEmbeddedComputation(std::move(comp)); + } + return true; } } // namespace StatusOr> ParseHloString( absl::string_view str, const HloModuleConfig& config) { - HloParser parser(str, config); - if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); - } - return parser.ConsumeHloModule(); + auto module = absl::make_unique(/*name=*/"_", config); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module.get())); + return std::move(module); } StatusOr> ParseHloString(absl::string_view str) { - HloModuleConfig config; - return ParseHloString(str, config); + auto module = absl::make_unique(/*name=*/"_", HloModuleConfig()); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module.get())); + return std::move(module); } -StatusOr> ParseHloOpToModule( - absl::string_view str, absl::string_view name) { - 
HloModuleConfig config; - HloParser parser(str, config); - auto builder = absl::make_unique(string(name)); - string root_name; - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); - std::unique_ptr computation = builder->Build(); - auto module = absl::make_unique(string(name), config); - module->AddEntryComputation(std::move(computation)); - return std::move(module); +Status ParseHloString(absl::string_view str, HloModule* module) { + TF_RET_CHECK(module->computation_count() == 0); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module)); + return Status::OK(); } StatusOr ParseSharding(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseShardingOnly(); } StatusOr ParseWindow(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseWindowOnly(); } StatusOr ParseConvolutionDimensionNumbers( absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseConvolutionDimensionNumbersOnly(); } +StatusOr ParsePaddingConfig(absl::string_view str) { + HloParser parser(str); + return parser.ParsePaddingConfigOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 0c64b50481bf2e86a2c588fbf2d77226c8428b7c..81eeb9f13bf7f06123c0b35e9f3352c197866a7a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -30,18 +30,18 @@ namespace xla { // For details about the syntax accepted by this parser, see // g3doc/hlo_parser.md. -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with the given config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with the given config. 
StatusOr> ParseHloString( absl::string_view str, const HloModuleConfig& config); -// Parses the text for a single HLO operation into an HLO module with a function -// that runs that operation (with the same parameters) as its entry computation. -StatusOr> ParseHloOpToModule( - absl::string_view str, absl::string_view name = "single_op"); +// Given a string in the HloModule::ToString() format, parses the string and +// builds the HloModule in place at the given module pointer. 'module' must +// point to an empty module (no computations). +Status ParseHloString(absl::string_view str, HloModule* module); -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with default config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". @@ -59,6 +59,9 @@ StatusOr ParseConvolutionDimensionNumbers( // sharding, i.e. just the rhs of the "sharding={...}" attribute string. StatusOr ParseSharding(absl::string_view str); +// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". +StatusOr ParsePaddingConfig(absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b3d3ccda743b998e478daf678d2b417061212754..19f84d8bd28371518e44e38614b8a81fa920985f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -382,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1 + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default} } )" @@ -395,7 +397,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf } )" @@ -408,7 +410,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 + ROOT %convolution-base-dilated = 
f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f } )" @@ -800,6 +802,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] { ROOT %constant = u64[] constant(9223372036854775807) } +)" +}, +// CustomCallWithLayoutConstraints +{ +"CustomCallWithLayoutConstraints", +R"(HloModule CustomCallWithLayoutConstraints + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}} +} + +)" +}, +// CustomCallWithLayoutConstraintsNoOperands +{ +"CustomCallWithLayoutConstraintsNoOperands", +R"(HloModule CustomCallWithLayoutConstraintsNoOperands + +ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] { + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} + +)" +}, +// CustomCallWithLayoutConstraintsTupleShapes +{ +"CustomCallWithLayoutConstraintsTupleShapes", +R"(HloModule CustomCallWithLayoutConstraintsTupleShapes + +ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} +} + )" }, }); @@ -964,6 +1003,21 @@ ENTRY Sort { ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} } +)" +}, +// Sort (Key, Value, Value, Value) +{ +"SortManyValues", +R"(HloModule sort 
+ +ENTRY Sort { + keys = f32[1024,16]{0,1} parameter(0) + values.0 = s32[1024,16]{0,1} parameter(1) + values.1 = u32[1024,16]{0,1} parameter(2) + values.2 = f32[1024,16]{0,1} parameter(3) + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} +} + )" }, // Conditional @@ -1000,6 +1054,18 @@ ENTRY CustomCall { ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" } +)" +}, +// CustomCall with opaque value. +{ +"CustomCallWithOpaque", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque" +} + )" }, // Variables with non-default names @@ -1096,6 +1162,18 @@ ENTRY AllToAllWithSubgroups { ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} } +)" +}, +// collective-permute +{ +"CollectivePermute", +R"(HloModule CollectivePermute + +ENTRY CollectivePermute { + input = f32[128,32]{0,1} parameter(0) + ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} +} + )" }, // Iota @@ -1104,69 +1182,113 @@ ENTRY AllToAllWithSubgroups { R"(HloModule iota ENTRY Iota { - ROOT iota = f32[100]{0} iota() + ROOT iota = f32[100]{0} iota(), iota_dimension=0 } )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" } 
)" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } - }); + +)" +} +}); // clang-format on } -class HloParserTest : public ::testing::Test, - public ::testing::WithParamInterface { +// The test class for those tests defined above which round-trip through the +// parser and ToString is templatized on two bool parameters: +// +// short_form : used for the "short" test cases which use the ShortParsable +// output form. +// proto_round_trip : whether the module should also be round-tripped through +// HloProto form. This provides much better coverage for the proto +// serialization/deserialization. +// +// The proto_round_trip=true case also technically covers the Parser->ToString +// roundtrip as well, but separating out the Parser->ToString roundtrip as its +// own test provides better isolation and could conceivably catch weirdo bugs +// which are hidden by interaction between the textual and proto roundtripping. +template +class HloParameterizedParserTest + : public ::testing::Test, + public ::testing::WithParamInterface { protected: - static void ExpectHasSubstr(string_view s, string_view expected) { - EXPECT_TRUE(absl::StrContains(s, expected)) - << "'" << s << "' does not contain '" << expected << "'"; - } - // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and // checks that the it equals the original string. 
void ExpectEqual() { const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, result.ValueOrDie()->ToString( - HloPrintOptions().set_print_large_constants(true))); - } -}; - -class HloParserShortTest : public HloParserTest { - protected: - void ExpectEqualShort() { - const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, - result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(original)); + if (proto_round_trip) { + TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( + module->ToProto(), module->config())); + } + if (short_form) { + EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable())); + } else { + EXPECT_EQ( + original, + module->ToString(HloPrintOptions().set_print_large_constants(true))); + } } }; -TEST_P(HloParserTest, Run) { ExpectEqual(); } +// These using shenanigans are required because the TEST_P macro doesn't like +// template instantiations which contain commas. 
+using HloParserTestLong = HloParameterizedParserTest; +using HloParserTestLongProto = HloParameterizedParserTest; +using HloParserTestShort = HloParameterizedParserTest; +using HloParserTestShortProto = HloParameterizedParserTest; -TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } +TEST_P(HloParserTestLong, Run) { ExpectEqual(); } +TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } +TEST_P(HloParserTestShort, Run) { ExpectEqual(); } +TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, ::testing::ValuesIn(CreateTestCases()), TestDataToString); - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, ::testing::ValuesIn(CreateShortTestCases()), TestDataToString); +class HloParserTest : public ::testing::Test { + protected: + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; + } +}; + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = ParseHloString(original); @@ -1234,7 +1356,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) 
} @@ -1693,6 +1815,25 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, SameNameDiffComputations) { + const string original = R"(HloModule same_names: +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT result = f32[] add(p0, p1) +} + +ENTRY ReduceR3ToR2 { + p0 = f32[8,16,256]{2,1,0} parameter(0) + p1 = f32[] constant(0) + ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); + ASSERT_NE(module->entry_computation(), nullptr); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + TEST_F(HloParserTest, ParseSharding) { const string original = "{maximal device=42}"; TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); @@ -1713,6 +1854,25 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) { + const string original = "0_1x2_3"; + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original)); + EXPECT_EQ(original, PaddingConfigToString(dnums)); +} + +TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) { + const string original = "0_1_0x2_3_4"; + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original)); + EXPECT_EQ(original, PaddingConfigToString(dnums)); +} + +TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4")); + // The extra "_0" gets added to the canonical string because the other dim has + // interior padding. 
+ EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums)); +} + TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { @@ -1727,22 +1887,281 @@ TEST(HloParserSingleOpTest, SingleOp) { const string text = "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " "f32[2,4]{1,0} %x)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Parameter(0), op::Parameter(1))); } -TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { +TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { + const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; + StatusOr> module = ParseHloString(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("expects '=' in instruction")); +} + +TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; - StatusOr> module = ParseHloOpToModule(text); + StatusOr> module = ParseHloString(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("Operand had no shape in HLO text")); +} + +TEST(HloParserSingleOpTest, SingleOpNoNames) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, CanonicalOp) { + const string text 
= "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, CanonicalOpWithNested) { + const string text = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested) { + const string text = + R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= +{ + %param_0 = 
f32[3,2,1,1]{3,2,1,0} parameter(0) + %param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("does not exist: x")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[], f32[]) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); +} + +TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { + const string text = + R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + 
ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Convolution(op::Parameter(0), op::Parameter(1))); + auto* convolution = + Cast(computation->root_instruction()); + EXPECT_EQ(convolution->feature_group_count(), 1); +} + +TEST_F(HloParserTest, IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} 
%multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but in with a different schedule order. + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( - module.status().ToString(), - ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + +TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { + const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints + +ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: 
f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}} } +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Expected 2 operand layout constraints, 1 given"); +} + +TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) { + const string original = R"(HloModule CustomCallIncompatibleOperandConstraints + +ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "operand 1 is not compatible with operand shape"); +} + +TEST_F(HloParserTest, AllowShapeWhitespace) { + const string text = R"( +HloModule module + +ENTRY entry { + ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); +} + +// custom call incompatible shape. + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index f1ad0f9b0148cb3d5f938e7f5d220d6cb82ea98d..fdaac34386c5135d6bbeb372d7a9199344836c8d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -25,15 +26,45 @@ limitations under the License. namespace xla { // Base class for HLO passes. These are used with the HloPassPipeline to -// organize a sequence of passes. +// organize a sequence of passes. An HLO pass should not extend this class +// directly; it should extend HloModulePass or HloModuleGroupPass. class HloPassInterface { public: virtual ~HloPassInterface() = default; virtual absl::string_view name() const = 0; - // Run the pass on the given HLO module. Return whether it modified the + // Run the pass on the given HLO module. Returns whether it modified the // module. virtual StatusOr Run(HloModule* module) = 0; + + // Run the pass on the given HLO module group. Returns whether it modified the + // module group. Ideally, the module group variant would be named "Run" as + // well, but C++ does not handle overloaded virtual methods well. + virtual StatusOr RunOnModuleGroup(HloModuleGroup* module_group) = 0; +}; + +// Base class for passes which are module-scoped. +class HloModulePass : public HloPassInterface { + public: + // Runs the pass on a module group by iterating through each module in the + // group. + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + TF_ASSIGN_OR_RETURN(bool module_changed, Run(module)); + changed |= module_changed; + } + return changed; + }; +}; + +// Base class for passes which are module-group scoped. These passes cannot run +// on an HLO module. 
+class HloModuleGroupPass : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override { + return InternalError("Module group pass cannot be run on a module"); + } }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index df99e131d862a989b191bb3fdb49dff9fb7a3712..5e004ce78ac1fd6da18ab2a54d23ef27e9586cf6 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,119 +17,140 @@ limitations under the License. #include -#include "absl/strings/str_cat.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { -namespace { -using absl::StrAppend; -using absl::StrCat; - -void DumpModuleGraph(const HloModule& module, const string& message) { - hlo_graph_dumper::MaybeDumpHloModule(module, message); - VLOG(3) << "HLO " << message << ":"; - XLA_VLOG_LINES(3, module.ToString()); +template +Status HloPassPipeline::RunInvariantCheckers( + HloT* hlo, absl::string_view after_pass_name) { + for (auto& invariant_checker : invariant_checkers_) { + VLOG(1) << " Invariant checker " << invariant_checker->name(); + StatusOr changed_status = RunHelper(invariant_checker.get(), hlo); + VLOG(1) << " Invariant checker done " << invariant_checker->name(); + if (!changed_status.ok()) { + VLOG(2) << "Failed invariant check:"; + XLA_VLOG_LINES(2, hlo->ToString()); + return Status(changed_status.status().code(), + 
absl::StrCat(changed_status.status().error_message(), + "\n\nFailed after ", after_pass_name)); + } + TF_RET_CHECK(!changed_status.ValueOrDie()) + << "invariant checkers must not change the graph"; + } + return Status::OK(); } -void DumpModuleProto(const HloModule& module, const string& dump_to, - const string& pipeline_name, const string& pass_name) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static auto* const module_id_to_pass_number = - new tensorflow::gtl::FlatMap(); - - tensorflow::mutex_lock lock(mu); - const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; +template +StatusOr HloPassPipeline::RunPassesInternal( + HloT* hlo, absl::Span passes) { + string last_pass_name = "pipeline-start"; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); + bool changed = false; + for (HloPassInterface* pass : passes) { + VLOG(1) << " HLO pass " << pass->name(); + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/pass->name()); + TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); + changed |= pass_changed; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name())); + last_pass_name = string(pass->name()); + } + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/"pipeline-end"); + return changed; +} - const string mod_name = SanitizeFileName(tensorflow::strings::Printf( - "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, - pipeline_name.c_str(), pass_name.c_str())); +std::vector HloPassPipeline::GetEnabledPasses( + const DebugOptions& debug_options) { + auto repeated_field = debug_options.xla_disable_hlo_passes(); + absl::flat_hash_set disabled_pass_names(repeated_field.begin(), + repeated_field.end()); + if (!disabled_pass_names.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << absl::StrJoin(disabled_pass_names, ", "); + } - TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), - 
dump_to, mod_name)); + std::vector enabled_passes; + for (auto& pass : passes_) { + if (disabled_pass_names.count(string(pass->name())) == 0) { + enabled_passes.push_back(pass.get()); + } + } + return enabled_passes; } -} // namespace -StatusOr HloPassPipeline::Run(HloModule* module) { - run_called_ = true; +void HloPassPipeline::MaybeDumpHlo(const HloModule& module, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + const string& proto_dump_path = + module.config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!proto_dump_path.empty()) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new absl::flat_hash_map(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string filename = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, name(), after_pass_name)); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory( + MakeHloProto(module), proto_dump_path, filename)); + } - VLOG(1) << "Running HLO pass pipeline " << name(); + const string message = + StrCat("after ", after_pass_name, ", before ", before_pass_name); + hlo_graph_dumper::MaybeDumpHloModule(module, message); + VLOG(3) << "HLO " << message << ":"; + XLA_VLOG_LINES(3, module.ToString()); +} - auto repeated_field = - module->config().debug_options().xla_disable_hlo_passes(); - tensorflow::gtl::FlatSet disabled_passes(repeated_field.begin(), - repeated_field.end()); - if (!disabled_passes.empty()) { - VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << absl::StrJoin(disabled_passes, ", "); +void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + for (const HloModule* module : module_group.modules()) { + MaybeDumpHlo(*module, after_pass_name, before_pass_name); } +} - auto 
run_invariant_checkers = [this, - module](const string& message) -> Status { - for (auto& invariant_checker : invariant_checkers_) { - VLOG(1) << " Invariant checker " << invariant_checker->name(); - StatusOr changed_status = invariant_checker->Run(module); - VLOG(1) << " Invariant checker done " << invariant_checker->name(); - if (!changed_status.ok()) { - VLOG(2) << "Module failed invariant check:"; - XLA_VLOG_LINES(2, module->ToString()); - return Status(changed_status.status().code(), - StrCat(changed_status.status().error_message(), - "\n\nFailed ", message)); - } - TF_RET_CHECK(!changed_status.ValueOrDie()) - << "invariant checkers must not change the graph"; - } - return Status::OK(); - }; +StatusOr HloPassPipeline::Run(HloModule* module) { + run_called_ = true; - string prefix = std::string(name()) + ": pipeline start"; - bool changed = false; - string message; - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("before running pipeline: ", name()))); - const string xla_dump_per_pass_hlo_proto_to = - module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), "pipeline_start"); - } + VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": " + << name(); - for (auto& pass : passes_) { - if (disabled_passes.count(std::string(pass->name())) > 0) { - VLOG(1) << " Skipping HLO pass " << pass->name() - << ", disabled by --xla_disable_hlo_passes"; - continue; - } + return RunPassesInternal(module, + GetEnabledPasses(module->config().debug_options())); +} - VLOG(1) << " HLO pass " << pass->name(); +StatusOr HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) { + run_called_ = true; - // Emit label containing: "after foo-pass, before bar-pass". 
- message.clear(); - StrAppend(&message, prefix, ", before ", pass->name()); - DumpModuleGraph(*module, message); - - TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("after running pass: ", pass->name()))); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), std::string(pass->name())); - } + VLOG(1) << "Running HLO pass pipeline on module group " + << module_group->name() << ": " << name(); - changed |= changed_this_pass; - prefix.clear(); - StrAppend(&prefix, name(), ": after ", pass->name()); + if (module_group->modules().empty()) { + VLOG(1) << "Module group is empty. Nothing to do."; + return false; } - DumpModuleGraph(*module, prefix + ", pipeline end"); - return changed; + + return RunPassesInternal( + module_group, + GetEnabledPasses(module_group->module(0).config().debug_options())); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 1d41a4dac1d8e2f392be0e4e856ead36a5b71d68..09e7033ea4ed88849d2f3665d04f74f3f388b3f5 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface { return *pass; } - // Run all passes on the given HLO module. StatusOr Run(HloModule* module) override; + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override; private: + // Returns the set of passes which are enabled. DebugOptions can selectively + // disable passes via --xla_disable_hlo_passes flag. 
+ std::vector GetEnabledPasses( + const DebugOptions& debug_options); + + // Maybe dumps the given module or module group depending on flag values + // contained in DebugOptions of module config. + void MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, + absl::string_view before_pass_name); + + // Runs the invariant checker on the given HLO. HloT can be either HloModule + // or HloModuleGroup. + template + Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name); + + // Helper which runs the given pass on the given HLO. HloT can be either + // HloModule or HloModuleGroup. + template + StatusOr RunPassesInternal(HloT* hlo, + absl::Span passes); + + // Helpers which run the given passes on the given HLO construct. These + // helpers enable templating of the core of the pipeline logic by providing + // HloModule and HloModuleGroup specific methods with the same name. + static StatusOr RunHelper(HloPassInterface* pass, HloModule* module) { + return pass->Run(module); + } + static StatusOr RunHelper(HloPassInterface* pass, + HloModuleGroup* module_group) { + return pass->RunOnModuleGroup(module_group); + } + const string name_; std::vector> passes_; std::vector> invariant_checkers_; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee8cb12b231718e09f6ac0d05d7a6887f4c4d746 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloPassPipelineTest : public HloVerifiedTestBase { + protected: + StatusOr ParseModuleGroup( + absl::Span hlo_strings) { + HloModuleGroup group(TestName()); + for (const string& hlo_string : hlo_strings) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + group.push_back(std::move(module)); + } + return std::move(group); + } +}; + +// A module pass which renames instructions named 'foo' to 'bar'. +class FooToBarModulePass : public HloModulePass { + absl::string_view name() const override { return "foo2bar"; } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "foo") { + instruction->SetAndSanitizeName("bar"); + changed = true; + } + } + } + return changed; + } +}; + +// A module group pass which renames instructions named 'baz' to 'qux'. 
+class BazToQuxModuleGroupPass : public HloModuleGroupPass { + absl::string_view name() const override { return "baz2qux"; } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "baz") { + instruction->SetAndSanitizeName("qux"); + changed = true; + } + } + } + } + return changed; + } +}; + +// An invariant checker pass which returns an error if there exists an +// instruction named 'bar'. +class BarBlowerUpper : public HloModulePass { + absl::string_view name() const override { return "bar-blower-upper"; } + + StatusOr Run(HloModule* module) override { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "bar") { + return InternalError("Module has instruction named bar"); + } + } + } + return false; + } +}; + +TEST_F(HloPassPipelineTest, ModulePassChanged) { + // Test an HLO module pass which changes a module. + const string module_str = R"( +HloModule ModulePassChanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "foo"); + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_EQ(root->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, ModulePassUnchanged) { + // Test an HLO module pass which does not change a module. 
+ const string module_str = R"( +HloModule ModulePassUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT blahblah = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(HloPassPipelineTest, MixedPipeline) { + // Test a pipeline with both a module pass and a module group pass. + const string module_0_str = R"( +HloModule MixedPipeline.1 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT baz = f32[] multiply(a, b) +} +)"; + const string module_1_str = R"( +HloModule MixedPipeline.0 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group, + ParseModuleGroup({module_0_str, module_1_str})); + + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + pipeline.AddPass(); + + HloInstruction* root0 = + module_group.module(0).entry_computation()->root_instruction(); + HloInstruction* root1 = + module_group.module(1).entry_computation()->root_instruction(); + EXPECT_EQ(root0->name(), "baz"); + EXPECT_EQ(root1->name(), "foo"); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pipeline.RunOnModuleGroup(&module_group)); + EXPECT_TRUE(changed); + + EXPECT_EQ(root0->name(), "qux"); + EXPECT_EQ(root1->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, InvariantChecker) { + const string module_str = R"( +HloModule InvariantChecker + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + { + // Run a pipeline with just the invariant checker. It should not fail + // because there is no 'bar' instruction in the module. 
+ HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); + } + + { + // Run a pipeline which renames 'foo' to 'bar' then an invariant checker + // which fails if there is an instruction named 'bar'. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after foo2bar")); + } + + { + // Run the invariant-checker only pipeline again. It should fail this time. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after pipeline-start")); + } +} + +TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) { + // Running a module group pass on a module should produce an error. 
+ const string module_str = R"( +HloModule ModuleGroupPassOnModule + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Module group pass cannot be run on a module")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558d185d1e022660d9a1d23176d0d96bf..cf33668f5bfa64a7843efc76e9f6768d18533240 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include @@ -23,11 +24,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } @@ -39,17 +37,28 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + TF_RETURN_IF_ERROR( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + 
.status()); + return std::move(module); +} + StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } - if (!hlo_proto.hlo_module().has_program_shape()) { + if (!hlo_proto.hlo_module().has_host_program_shape()) { return NotFound("HloProto missing program shape."); } std::vector parameter_shapes; - const auto& program_shape = hlo_proto.hlo_module().program_shape(); + const auto& program_shape = hlo_proto.hlo_module().host_program_shape(); for (const Shape& shape : program_shape.parameters()) { parameter_shapes.push_back(&shape); } @@ -60,14 +69,14 @@ StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } - if (!hlo_proto.hlo_module().has_program_shape()) { + if (!hlo_proto.hlo_module().has_host_program_shape()) { return NotFound("HloProto missing program shape."); } - if (!hlo_proto.hlo_module().program_shape().has_result()) { + if (!hlo_proto.hlo_module().host_program_shape().has_result()) { return NotFound("HloProto missing result in its program shape"); } - return &hlo_proto.hlo_module().program_shape().result(); + return &hlo_proto.hlo_module().host_program_shape().result(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 3d9c375cd5d26f92cf8316f78789daf4fc08c927..1db82dd6fcaa5d7fe7d65894c1021105f0b26266 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,12 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Create an HLO state from serialized representation. In addition to +// creating the proto with HloModule::CreateFromProto(...) it also +// uses HloVerifier to ensure basic invariants are held. 
+StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config); + // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. StatusOr> EntryComputationParameterShapes( diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2a07b6fcbc243d955e136ccdf097c8155a115845..2d5197be9e6f69f698729e06b7506a5bc6260bcd 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -24,7 +24,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalarF32(instruction->shape())) { + ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) { *out = instruction->literal().Get({}); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 01b088a957554821e65db7bf9cedf334db49728f..961930f0a888e90f86e4354fa1373a303af8ec2f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. 
namespace xla { HloReachabilityMap::HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions) + absl::Span instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { @@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap( } bool HloReachabilityMap::SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; @@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion( } void HloReachabilityMap::FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); } void HloReachabilityMap::SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 48215d32a8284919cce6beb1663e6a723eefc1c4..5a5f01f8fd647c74217c80ce4a7633b8957e335f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -19,11 +19,11 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -42,7 +42,7 @@ class HloReachabilityMap { // Sets up a graph with no edges and where the nodes correspond to the given // instructions. explicit HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where @@ -54,13 +54,12 @@ class HloReachabilityMap { // vector in the internal graph of this HloReachabilityMap for the given // instruction and does not transitively update any other part of the // adjacency matrix. - bool SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, - const HloInstruction* instruction); + bool SetReachabilityToUnion(absl::Span inputs, + const HloInstruction* instruction); // As above, but faster because it does not check if the reachability changed. void FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true @@ -141,7 +140,7 @@ class HloReachabilityMap { // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. void SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); // Return the index of the given instruction. The value is used to index into @@ -155,7 +154,7 @@ class HloReachabilityMap { // Dense assignment from HloInstruction* to number. These numbers index // into the bit_vectors_ vector and into the bits within a BitVector. 
- tensorflow::gtl::FlatMap indices_; + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 585c95972b0e01abc14543205af71b4b0c0bdf3c..d9848cee0bfa904a90aea4626c3ee62c2cbb45b6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloVerifiedTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 6c6e7c6fecea6447aea8c6b01f30867a50f38e22..49e46ecd00ee4370f3e93746348373b79febed3d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,27 +20,28 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -76,7 +77,7 @@ bool IsRematerializable(const HloInstruction* instruction) { // cache before, and eventually calling the IsRematerializable() API. bool CanBeRematerialized( const HloInstruction* instruction, - tensorflow::gtl::FlatMap* remat_able) { + absl::flat_hash_map* remat_able) { auto it = remat_able->find(instruction); if (it != remat_able->end()) { return it->second; @@ -202,8 +203,8 @@ class InstructionList { // On object construction this ordinal is precisely the instruction's index // in the list. Later, instructions inserted via InsertBefore receive // duplicate values. However, monotonicity is preserved. 
- void InsertBeforeInstructions( - Item* to_insert, tensorflow::gtl::ArraySlice before_instructions) { + void InsertBeforeInstructions(Item* to_insert, + absl::Span before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" << absl::StrJoin(before_instructions, ", ", @@ -269,7 +270,7 @@ class InstructionList { Item* first_; // Item for each instruction. - tensorflow::gtl::FlatMap item_map_; + absl::flat_hash_map item_map_; }; // Return the items which use the given LogicalBuffer. Sets @@ -504,7 +505,7 @@ MemoryUsageTracker::MemoryUsageTracker( PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); - tensorflow::gtl::FlatMap + absl::flat_hash_map logical_buffer_to_buffer_id; for (auto* item = instruction_list_.first(); item != nullptr; @@ -855,7 +856,7 @@ int64 RematerializationCost(const HloInstruction* instruction, Item* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, int64 memory_limit_bytes, - tensorflow::gtl::FlatMap* remat_able) { + absl::flat_hash_map* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -962,8 +963,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( } StatusOr HloRematerialization::RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, + HloComputation* computation, HloSchedule* schedule, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -971,7 +971,8 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list(sequence->at(computation)); + InstructionList instruction_list( + 
schedule->sequence(computation).instructions()); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -981,10 +982,10 @@ StatusOr HloRematerialization::RematerializeComputation( // rematerialization is essentially a move). If the next rematerialization of // the instruction is also a move then the rematerialization is added to the // blacklist. - tensorflow::gtl::FlatSet remat_move_instructions; + absl::flat_hash_set remat_move_instructions; // The map from instructions to their rematerializable status. - tensorflow::gtl::FlatMap remat_able; + absl::flat_hash_map remat_able; // The peak memory of the computation at any point in the instruction // sequence. @@ -1145,7 +1146,7 @@ StatusOr HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, + RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1179,12 +1180,12 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. 
- auto& dst = sequence->at(computation); - dst.clear(); + HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); + sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - dst.push_back(instruction); + sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1194,59 +1195,18 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run( - HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The sequence is constructed entirely by this method. - TF_RET_CHECK(sequence->empty()); - +StatusOr HloRematerialization::Run(HloModule* module) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes); + << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - - // First create a copy of the schedule which contains HloInstruction unique - // ids instead of HloInstruction*. This is necessary for updating the - // schedule below. - // TODO(b/113175018): Remove this when the HLO schedule is self-contained - // and can update itself. 
- tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(*sequence); - - SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - - // RemoveUnnecessaryCopies only considers interference when determining - // whether it is legal to remove a copy. However, copies in the graph may be - // necessary for other reason such as preventing a constant from being live - // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. - // TODO(b/80249101): Break copy insertion into several passes and run each - // one once in the regular HLO pipeline. - TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); - - // The passes above can add and remove copies, update the schedule to - // account for these transformations. Newly added instructions will be - // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); - - TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(module, *sequence), module)); - } + // Initialize pass object state. + computation_peak_memory_.clear(); + rematerialized_computations_.clear(); + instructions_rematerialized_ = 0; + net_instructions_added_ = 0; + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1255,14 +1215,14 @@ StatusOr HloRematerialization::Run( // by the caller. 
int64 module_output_size = 0; ShapeUtil::ForEachSubshape( - module->entry_computation()->root_instruction()->shape(), + module->result_shape(), [&module_output_size, this](const Shape& subshape, const ShapeIndex& /*index*/) { module_output_size += size_function_(subshape); }); const int64 adjusted_memory_limit_bytes = - memory_limit_bytes - module_output_size; + memory_limit_bytes_ - module_output_size; VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -1271,12 +1231,14 @@ StatusOr HloRematerialization::Run( // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, sequence](const CallGraphNode& node) -> Status { + [this, module](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], ComputePeakMemory(node.computation(), - sequence->at(node.computation()))); + module->schedule() + .sequence(node.computation()) + .instructions())); } return Status::OK(); }, @@ -1294,9 +1256,10 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), sequence, - adjusted_memory_limit_bytes)); + TF_ASSIGN_OR_RETURN( + bool changed, + RematerializeComputation(module->entry_computation(), &module->schedule(), + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -1305,30 +1268,7 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. 
- for (const auto* computation : module->MakeNonfusionComputations()) { - if (sequence->at(computation).size() != computation->instruction_count()) { - // A size mismatch between the computation instruction count and the size - // of the ordering of instructions can only be caused by DCE. Rebuild the - // order by removing the deleted instructions from the order. - tensorflow::gtl::FlatSet instruction_set; - for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction); - } - // Move the old order into a temporary vector, then build new order - // inplace. - std::vector& order = sequence->at(computation); - std::vector old_order; - using std::swap; - swap(order, old_order); - std::copy_if(old_order.begin(), old_order.end(), - std::back_inserter(order), - [&instruction_set](const HloInstruction* instruction) { - return ContainsKey(instruction_set, instruction); - }); - TF_RET_CHECK(sequence->at(computation).size() == - computation->instruction_count()); - } - } + TF_RETURN_IF_ERROR(module->schedule().Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1345,34 +1285,22 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes != nullptr) { - sizes->before_bytes = before_peak_memory; - sizes->after_bytes = current_peak_memory; + if (sizes_ != nullptr) { + sizes_->before_bytes = before_peak_memory; + sizes_->after_bytes = current_peak_memory; } XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << tensorflow::strings::Printf( - "Can't reduce memory use below %s (%lld bytes) by rematerialization; " - "only reduced to %s (%lld bytes)", - HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, - 
HumanReadableNumBytes(current_peak_memory).c_str(), - current_peak_memory); + if (current_peak_memory > memory_limit_bytes_) { + LOG(WARNING) << absl::StrFormat( + "Can't reduce memory use below %s (%d bytes) by rematerialization; " + "only reduced to %s (%d bytes)", + HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, + HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; } -/* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const HloRematerialization::ShapeSizeFunction& size_function, - int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion) { - HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, - copy_insertion); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ec004350ad88ff31ece90ec419d90a55b965166..70d83c04f07ca7fd0139f586869e8fe688f958f4 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -15,18 +15,27 @@ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include 
"tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { -class HloRematerialization { +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloModulePass { public: using ShapeSizeFunction = std::function; @@ -37,10 +46,7 @@ class HloRematerialization { int64 after_bytes; }; - // Rematerialize HLO instructions in the given module to reduce peak memory - // use below memory_limit_bytes where memory use is defined as the total size - // of all live HLO instruction values. Parameters and constants are included - // in memory use estimates. Method parameters: + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. @@ -48,60 +54,34 @@ class HloRematerialization { // memory_limit_bytes: The threshold number of bytes to reduce memory use to // via rematerialization. // - // hlo_module: HLO module to rematerialize instructions in. - // - // sequence: Should point to an empty HloModuleSequence. Upon return - // contains the HLO instruction order which was used for - // rematerialization. This is the order in which HLO instructions should - // be emitted to minimize memory use. - // - // sizes: Optional outparam that indicates the peak memory usage of the HLO - // module before/after rematerialization. - // - // copy_insertion: If non-null, run copy elision after scheduling. 
This - // pass is used to eliminate copies that were inserted by copy insertion - // before HLO scheduling. - // - // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy - // insertion is integrated with HLO scheduling. - // - // Returns whether any instructions were rematerialized. If memory use is - // already below the given limit then no instructions are rematerialized and - // false is returned. - // - // CSE will undo the effects of this optimization and should not be run after - // this pass. In general, this pass should be run very late immediately before - // code generation. - static StatusOr RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); - - protected: - HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) - : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + // sizes: Pointer to data structure which records the peak memory usage of + // the HLO module before/after rematerialization. Value are set during + // Run(). Can be nullptr. + HloRematerialization(const ShapeSizeFunction& size_function, + int64 memory_limit_bytes, RematerializationSizes* sizes) + : size_function_(size_function), + memory_limit_bytes_(memory_limit_bytes), + sizes_(sizes) {} ~HloRematerialization() {} + absl::string_view name() const override { return "rematerialization"; } + // Runs rematerialization on the given module. Returns whether the module was - // changed. memory_limit is the target maximum peak memory usage by the - // module. sequence should be an empty HloModuleSequence. Upon return sequence - // contains the memory-minimizing order in which to emit the HLO instructions. 
- StatusOr Run(HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit, RematerializationSizes* sizes, - CopyInsertion* copy_insertion); + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized. If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + StatusOr Run(HloModule* module) override; + protected: // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 computation_memory_limit); + StatusOr RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64 memory_limit_bytes); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at @@ -122,6 +102,14 @@ class HloRematerialization { // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; + // The threshold number of bytes to reduce memory use to via + // rematerialization. + const int64 memory_limit_bytes_; + + // Pointer to data structure which records the peak memory usage of the HLO + // module before/after rematerialization + RematerializationSizes* sizes_; + // Call graph of the hlo_module. std::unique_ptr call_graph_; @@ -129,14 +117,13 @@ class HloRematerialization { // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization // occurs. 
- tensorflow::gtl::FlatMap - computation_peak_memory_; + absl::flat_hash_map computation_peak_memory_; std::unique_ptr points_to_analysis_; // Set of computations which have had rematerialization // applied. Rematerialization is only applied once per computation. - tensorflow::gtl::FlatSet rematerialized_computations_; + absl::flat_hash_set rematerialized_computations_; // Count of the total instructions rematerialized. int64 instructions_rematerialized_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index ac8c97d380953764b66135ad1c5fcee0d481c004..f7e82fb1f88e856305f6f481a451d4cd64ba4acf 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. 
The computation looks like: @@ -141,13 +141,16 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr RunHloRematerialization( - int64 memory_limit_bytes, HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence) { + StatusOr RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. @@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,9 +189,13 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. 
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -203,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -242,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. 
@@ -276,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -316,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -382,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &sequence)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. 
EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -571,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. 
if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 7bd8a4a544b21a35f20eeed493f7e0528a7e87dd..fa7f216321988137dcf9104a324f5f7789869aa5 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -106,7 +106,7 @@ StatusOr HloRunner::TransferLiteralToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals) { + const absl::Span literals) { std::vector buffers; for (const Literal* literal : literals) { CHECK(literal != nullptr); @@ -118,16 +118,16 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals) { + const absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { - literal_pointers.push_back(literal.get()); + literal_pointers.push_back(&literal); } return TransferLiteralsToDevice(literal_pointers); } -StatusOr> HloRunner::TransferLiteralFromDevice( +StatusOr HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -135,10 +135,10 @@ StatusOr> HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr> HloRunner::Execute( +StatusOr HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -150,15 +150,15 @@ StatusOr> HloRunner::Execute( return TransferLiteralFromDevice(result); } -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes, 
ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { - argument_pointers.push_back(argument.get()); + argument_pointers.push_back(&argument); } return Execute( /*module=*/std::move(module), @@ -169,8 +169,8 @@ StatusOr> HloRunner::Execute( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { // Get service run options. se::Stream stream(backend().default_stream_executor()); stream.Init(); @@ -190,8 +190,8 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { @@ -204,7 +204,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } -StatusOr>> HloRunner::ExecuteReplicated( +StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -226,8 +226,7 @@ StatusOr>> HloRunner::ExecuteReplicated( // no arguments. 
std::vector argument_buffer_ptrs( options.num_replicas * options.arguments.size() + 1); - std::vector> - argument_buffer_slices; + std::vector> argument_buffer_slices; int64 index = 0; for (int64 i = 0; i < options.num_replicas; ++i) { int64 device = device_assignment(i, 0); @@ -291,9 +290,9 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + Literal literal; TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, options.outfeed_shape, literal.get())); + executor, options.outfeed_shape, &literal)); if (options.outfeed_values != nullptr) { options.outfeed_values->push_back(std::move(literal)); } @@ -311,10 +310,10 @@ StatusOr>> HloRunner::ExecuteReplicated( argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; - std::vector> exec_results; + std::vector exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cfc519063e837cb961c4c4fb1efe611a7fe273ba..2e934bf66ae43ea412f242030b874dddb6d3722d 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -72,7 +72,7 @@ class HloRunner { // A pointer to a vector where the outfeed values will be stored. If // nullptr, the values will be read and discarded. - std::vector>* outfeed_values = nullptr; + std::vector* outfeed_values = nullptr; // Whether the HLO passes should be run on the input module. Usually // saved modules are coming from after the HLO pass pipeline, so triggering @@ -104,43 +104,42 @@ class HloRunner { // Transfers data between the host and device. StatusOr TransferLiteralToDevice(const Literal& literal); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals); + const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals); - StatusOr> TransferLiteralFromDevice( - const ShapedBuffer& buffer); + const absl::Span literals); + StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. 
- StatusOr> Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - StatusOr> Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. - StatusOr>> ExecuteReplicated( + StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options); diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc new file mode 100644 index 0000000000000000000000000000000000000000..9972eb20774550817143cb27dd94667364cf68ec --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -0,0 +1,343 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { + +/* static */ StatusOr HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + absl::flat_hash_map id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + absl::flat_hash_map id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto 
instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : sequences_) { + int64 computation_id = id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + +void HloSchedule::set_sequence( + const HloComputation* computation, + absl::Span sequence) { + set_sequence(computation, HloInstructionSequence(sequence)); +} + +void HloSchedule::set_sequence(const HloComputation* computation, + HloInstructionSequence sequence) { + CHECK(computation->parent() == module_); + sequences_[computation->unique_id()] = std::move(sequence); +} + +HloInstructionSequence& HloSchedule::GetOrCreateSequence( + const HloComputation* computation) { + auto it = sequences_.find(computation->unique_id()); + if (it == sequences_.end()) { + // No sequence found for computation. Create and return an empty one. + CHECK(computation->parent() == module_); + return sequences_[computation->unique_id()]; + } else { + return it->second; + } +} + +const HloInstructionSequence& HloSchedule::sequence( + const HloComputation* computation) const { + return sequences_.at(computation->unique_id()); +} + +Status HloSchedule::UpdateComputationSchedule( + const HloComputation* computation) { + // Map from unique ID to HloInstruction pointer for instructions in the + // computation. 
+ absl::flat_hash_map id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); + } + + // Set of all HloInstructions in the schedule. + absl::flat_hash_set ids_in_schedule; + for (int id : sequences_.at(computation->unique_id()).ids()) { + InsertOrDie(&ids_in_schedule, id); + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // computation, but not in schedule) which use X. If an instruction is not in + // the map, then it has no users which are newly added instructions. + absl::flat_hash_map> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + absl::flat_hash_map unscheduled_operand_count; + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + if (instruction->operands().empty()) { + worklist.push(instruction); + } else { + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + HloInstructionSequence new_sequence; + + // Lambda which schedules all instructions on the worklist. 
+ auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + new_sequence.push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : sequences_.at(computation->unique_id()).ids()) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. Do not add + // it to the new schedule. + continue; + } + worklist.push(it->second); + schedule_worklist(); + } + + set_sequence(computation, std::move(new_sequence)); + return Status::OK(); +} + +Status HloSchedule::Update() { + // The schedule must contain a sequence for every non-fusion computation in + // the module, but can have sequences for computations which no longer exist + // (these are removed). + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() << " not in HloSchedule."; + } + if (sequences_.size() > nonfusion_computations.size()) { + // Schedule contains some computations which have been removed from the + // HloModule. Remove them from the schedule as well. 
+ absl::flat_hash_set nonfusion_computations_ids; + for (const HloComputation* computation : nonfusion_computations) { + nonfusion_computations_ids.insert(computation->unique_id()); + } + for (auto it = sequences_.begin(); it != sequences_.end();) { + if (nonfusion_computations_ids.count(it->first) == 0) { + sequences_.erase(it++); + } else { + ++it; + } + } + } + CHECK_EQ(sequences_.size(), nonfusion_computations.size()); + + for (const HloComputation* computation : nonfusion_computations) { + TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); + } + + TF_RETURN_IF_ERROR(Verify()); + return Status::OK(); +} + +Status HloSchedule::Verify() const { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(3, module_->ToString()); + XLA_VLOG_LINES(2, ToString()); + + // Verify schedule contains exactly the same set of non-fusion computations as + // module currently does. + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) + << "Schedule has " << sequences_.size() << " sequences, but module has " + << nonfusion_computations.size() << " non-fusion computations"; + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() + << " missing from HLO schedule."; + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. 
+ for (const HloComputation* computation : nonfusion_computations) { + absl::flat_hash_map instruction_position; + int pos = 0; + for (const HloInstruction* instruction : + sequence(computation).instructions()) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + +namespace { + +// Returns the computation in the given module with the given unique ID. Returns +// nullptr if no such computation exists. 
+const HloComputation* IdToComputation(const HloModule* module, int64 id) { + for (const HloComputation* computation : module->computations()) { + if (computation->unique_id() == id) { + return computation; + } + } + return nullptr; +} + +} // namespace + +string HloSchedule::ToString() const { + std::vector pieces; + + pieces.push_back("HloSchedule"); + for (const auto& id_sequence : sequences_) { + const HloComputation* computation = + IdToComputation(module_, id_sequence.first); + if (computation == nullptr) { + // The computation is not in the module and may have been deleted so it is + // not safe to dereference any HLO pointers. Just use the HLO unique ids + // stored in this object. + pieces.push_back( + absl::StrFormat("computation with id %d (no longer in HLO module):", + id_sequence.first)); + for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrCat(" ", id)); + } + } else { + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); + for (const HloInstruction* instruction : + id_sequence.second.instructions()) { + pieces.push_back(absl::StrCat(" ", instruction->name())); + } + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { + out << schedule.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000000000000000000000000000000000..0a714101ee587aa847fa674bbde5586287c51f33 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +class HloModule; + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + explicit HloInstructionSequence( + absl::Span instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). 
+ const std::vector& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. + std::vector instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. + static StatusOr CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr ToProto() const; + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const absl::flat_hash_map& sequences() const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. 
+ bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated is the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. 
+ absl::flat_hash_map sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1424569ac1f62e4b965876141f1eb40be4f15bea --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloScheduleTest : public HloTestBase {}; + +TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. 
+ const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + const std::vector& entry_schedule = + schedule.sequence(module->entry_computation()).instructions(); + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(entry_schedule, + schedule.sequence(module->entry_computation()).instructions()); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return 
absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); + }; + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. 
+ HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 4); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 3); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 2); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. 
+ body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(body).size(), 7); + EXPECT_EQ(schedule.sequence(cond).size(), 4); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(body).size(), 1); + EXPECT_EQ(schedule.sequence(cond).size(), 5); +} + +TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { + // Remove computations from a module and verify the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, 
[](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + HloInstruction* xla_while = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* init = xla_while->mutable_operand(0); + + // Replace the while with its init value. The conditional and body + // computations should then be dead. + TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); + + // DCE the dead code in the body. + HloDCE dce; + ASSERT_EQ(module->computation_count(), 3); + TF_ASSERT_OK(dce.Run(module.get()).status()); + ASSERT_EQ(module->computation_count(), 1); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h deleted file mode 100644 index d06b8d9a5cdef82380bd68ae0991a3957db80f48..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ - -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// A memory scheduler computes an execution sequence for the HLO instructions in -// 'computation' that minimizes peak memory, given a points-to analysis result -// that describes buffer aliasing, together with a target-specific size function -// that maps a tensor's logical size to its padded size. -typedef std::function>( - const HloComputation&, const TuplePointsToAnalysis&, - const LogicalBuffer::SizeFunction&, - const tensorflow::gtl::FlatMap&)> - MemorySchedulerAlgorithm; - -// List scheduler -StatusOr> ListMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// DFS-order scheduler -StatusOr> DFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// Naive Post Order scheduler -StatusOr> PostOrderMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// The default scheduling algorithm. 
Runs both the list scheduler -// and the DFS scheduler, and chooses whichever returns a lower min-memory, -// not accounting for fragmentation. -StatusOr> DefaultMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& - memory_by_computation); - -// Returns an HloModuleSequence which seeks to minimize the memory required for -// the computation. size_function is the function returning the number of bytes -// required for a LogicalBuffer. -StatusOr ScheduleComputationsInModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); - -// Computes the schedule for a single computation. -// Currently only used by the GPU backend. -StatusOr> ScheduleOneComputation( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); - -// Transforms the given schedule such that it is (again) a valid schedule for -// the module. This is used to update a schedule after the HLO module has been -// transformed in some way. In general, the only transformations to the module -// for which a schedule can be updated is the addition or removal of -// instructions to/from the module. Updating the schedule after new dependencies -// between existing instructions in the module is not supported and may result -// in an error status returned. -// -// Instructions in the module which also exist in the given schedule will remain -// in the same order in the updated schedule. Instructions which exist in the -// module but not in the given schedule will be placed as early as possible in -// the updated schedule. -// -// 'id_sequence' is a mirror of the given schedule 'sequence' but with -// HloInstruction ids rather than HloInstruction pointers. 
This should be -// constructed using ComputeIdSchedule below after the schedule is constructed -// but before the HLO module is transformed. -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence); - -// Constructs a copy of the given schedule but with HloInstruction unique ids -// rather than HloInstruction pointers. This is necessary for updating a -// schedule as HloInstruction points in the schedule may become invalid if -// instructions are removed from the module. Used by UpdateSchedule above.. -// TODO(b/113175018): Remove this function when HLO schedule is its own class. -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); - -// Verifies that the given schedule is valid for the given module. Specifically, -// the schedule contains exactly the instructions in the module and every -// dependency in the module is satisfied in the schedule. -Status VerifySchedule(const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc deleted file mode 100644 index 930801288a0ea0fa7fd75dd38610430ae7010b5a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ /dev/null @@ -1,667 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" - -#include -#include - -#include "tensorflow/compiler/xla/service/heap_simulator.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace xla { -namespace { - -class HloSchedulingTest : public HloTestBase {}; - -TEST_F(HloSchedulingTest, LastUseScheduledFirst) { - // Tests scheduling of the following HLO code: - // - // %ab = abs(%param) - // %exp = exp(%param) - // %add = add(%ab, %exp) - // %negate = negate(%exp) - // %sub = subtract(%add, %negate) - // - // %add should be scheduled before %negate because %add is the last (and only) - // use of %ab. Scheduling %add first then frees up %ab's buffer. 
- const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); - auto builder = HloComputation::Builder(TestName()); - auto param = - builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); - auto ab = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); - auto sub = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - - // The first instruction should be the parameter and the last the root "sub". 
- EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - - SequentialHloOrdering ordering(module.get(), sequence); - EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); -} - -TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { - const char* module_str = R"( -HloModule test_aliasing_module - -ENTRY root { - param = s32[1000] parameter(0) - p0 = s32[1000] copy(param) - p1 = s32[1000] copy(param) - t = (s32[1000], s32[1000]) tuple(p0, p1) - a = s32[1000] get-tuple-element(t), index=0 - b = s32[1000] get-tuple-element(t), index=1 - c = s32[1000] add(a, b) - d = s32[1000] add(c, b) - e = s32[1000] add(c, c) - f = s32[1000] add(e, e) - ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - - std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : - sequence.at(module->entry_computation())) { - instructions_by_name[instruction->name()] = instruction; - } - - // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), - sequence.at(module->entry_computation()).front()); - EXPECT_EQ(instructions_by_name.at("result"), - sequence.at(module->entry_computation()).back()); - - // Instructions "d" and "e" will both be schedulable at the same time, but - // instruction "d" allows us to free the buffer of "p1", so the list scheduler - // should prefer it. 
- SequentialHloOrdering ordering(module.get(), sequence); - EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), - instructions_by_name.at("e"))); -} - -TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { - // %WhileCond (cond_param: f32[4]) -> pred[] { - // %cond_param = f32[4]{0} parameter(0) - // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) - // ROOT %not-equal-to = pred[] not-equal-to( - // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) - // } - // %WhileBody (body_param: f32[4]) -> f32[4] { - // %body_param = f32[4]{0} parameter(0) - // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // ROOT %subtract = f32[4]{0} subtract( - // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) - // } - // %ListAccountsForSubcomputations () -> f32[2,4] { - // %constant.3 = f32[2,4]{1,0} constant( - // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) - // %transpose = f32[2,4]{1,0} transpose( - // f32[2,4]{1,0} %constant.3), dimensions={0,1} - // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), - // condition=%WhileCond, - // body=%WhileBody - // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} - // ROOT %add = f32[2,4]{1,0} add( - // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) - // } - - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto 
cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - // transpose(matrix) + bcast(while) - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - HloInstruction* while_loop = - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, while_init)); - - // Creates 32 bytes and frees 16 - HloInstruction* bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); - - HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); - // Creates 32 bytes - HloInstruction* transpose = builder.AddInstruction( - HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); - - // Creates 32 bytes and frees 64 - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the 
sequence. - auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - SequentialHloOrdering ordering(module.get(), sequence); - // This schedule is an example of List's greedy heuristics being suboptimal. - // The while_loop is more expensive than transpose, so it would have been - // better to schedule it first, instead of during the busy time. - EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); - EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); - - tensorflow::gtl::FlatMap memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The output buffer is aliased, - // so we don't double count. - EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - -TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { - auto builder = HloComputation::Builder(TestName()); - const auto TUPLE_SIZE = 1; - const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); - - // Wrap lit in abs because constants are considered free by - // IgnoreInstruction, and it skews the accounting. 
- auto lit = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); - auto abs_const = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); - - auto abs_abs1 = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); - auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice({abs_abs1}))); - auto tuple_elm = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); - - auto abs_abs2 = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); - - builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, - tuple_elm, abs_abs2)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); - - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); - // tuple allocates the tuple buffer and doesn't free anything. - // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. - // abs_abs2 should be scheduled before tuple by List. 
- EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); -} - -TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { - const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); - HloComputation::Builder builder(TestName()); - - auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); - auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); - auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); - auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); - - auto tuple_elm = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); - - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); - - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - - auto module = CreateNewModule(); - auto* computation = module->AddEntryComputation(builder.Build()); - - auto fusion = computation->CreateFusionInstruction( - {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); - - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); - // fusion allocates memory for the tuple elements and doesn't free anything, - // so it's more expensive than exp. 
- EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); -} - -TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, while_init)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); - // 
Verify that all instructions are in the sequence. - auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - - tensorflow::gtl::FlatMap memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. Cond is the largest one. - // The output buffer of the while is aliased. - EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - -TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { - // Updating the schedule of an unchanged HLO module should not affect the - // schedule at all. 
- const string module_str = R"( -HloModule UpdateScheduleUnchanged - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - std::vector entry_schedule = sequence.begin()->second; - - EXPECT_EQ(entry_schedule.size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(entry_schedule, sequence.begin()->second); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { - // Add some additional instructions to a module and verify the schedule can be - // updated. 
- const string module_str = R"( -HloModule UpdateScheduleWithNewInstructions - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - HloComputation* entry = module->entry_computation(); - const Shape shape = entry->root_instruction()->shape(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, constant, entry->root_instruction())); - entry->set_root_instruction(sub); - - auto in_schedule = [&](const HloInstruction* hlo) { - return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), - hlo) != sequence.at(entry).end(); - }; - - EXPECT_EQ(sequence.at(entry).size(), 6); - EXPECT_FALSE(in_schedule(constant)); - EXPECT_FALSE(in_schedule(sub)); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 8); - EXPECT_TRUE(in_schedule(constant)); - EXPECT_TRUE(in_schedule(sub)); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { - // Add and delete some instructions from a module and verify that the schedule - // can be updated successfully. 
- const string module_str = R"( -HloModule UpdateScheduleWithAddedAndDeletedInstruction - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Set the entry root to some expression containing just a parameter and a - // constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* new_root = entry->AddInstruction( - HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, - constant, entry->parameter_instruction(0))); - entry->set_root_instruction(new_root); - - // DCE should remove everything but the parameters and the newly added code. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 4); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { - // Completely replace a module with an entirely new set of instructions and - // verify that the schedule can be updated successfully. 
- const string module_str = R"( -HloModule UpdateScheduleWithCompletelyReplacedModule - -ENTRY main { - a = f32[] constant(42.0) - b = f32[] constant(123.0) - ROOT sum = f32[] add(a, b) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Replace the entry computation with the negation of a constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNegate, constant)); - entry->set_root_instruction(new_root); - - // DCE the old instructions. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 3); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 2); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { - // Create changes to more than one computation in an HLO module and verify - // that the schedule can be updated. 
- const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - const HloInstruction* xla_while = - module->entry_computation()->root_instruction()->operand(0); - HloComputation* body = xla_while->while_body(); - HloComputation* cond = xla_while->while_condition(); - - // Negate the root of the cond. 
- cond->set_root_instruction(cond->AddInstruction( - HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kNot, cond->root_instruction()))); - - // Replace the body with a computation which just passes through its - // parameter. - body->set_root_instruction(body->parameter_instruction(0)); - - // DCE the dead code in the body. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(body).size(), 7); - EXPECT_EQ(sequence.at(cond).size(), 4); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(body).size(), 1); - EXPECT_EQ(sequence.at(cond).size(), 5); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 980dae07ceec20a945f7db5f1377c6f5c08af47a..70a860c356ca2fb1c4c973ea3d96c50fabc2c7c2 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -54,9 +55,8 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { return HloSharding(flattened_list); } -HloSharding HloSharding::Tuple( - const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { +HloSharding HloSharding::Tuple(const Shape& tuple_shape, + absl::Span shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); @@ -142,7 +142,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; - tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { + tile_assignment_.Each([&](absl::Span index, int64 d) { if (d == device) { ret_index = {index.begin(), index.end()}; } @@ -151,8 +151,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { return ret_index; } -int64 HloSharding::DeviceForTileIndex( - tensorflow::gtl::ArraySlice index) const { +int64 HloSharding::DeviceForTileIndex(absl::Span index) const { CHECK(!replicated_); CHECK(!IsTuple()); if (maximal_) { @@ -319,7 +318,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, int32 core) { + [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. 
if (status.ok()) { if (core >= num_devices) { @@ -371,10 +370,28 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || - proto.tile_assignment_devices().size() == 1) { + } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } + + TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL) + << "Maximal sharding is expected to have single device assignment, but " + << proto.tile_assignment_devices().size() << " has provided."; + + TF_RET_CHECK(proto.tile_assignment_devices().size() > 1); + TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); + + // RE: the product of tile assignment tensor dimensions must be + // equal to tile_assignment_devices.size(). + int64 product_of_dimensions = 1; + for (auto dimension : proto.tile_assignment_dimensions()) { + TF_RET_CHECK(dimension > 0); + product_of_dimensions = + MultiplyWithoutOverflow(product_of_dimensions, dimension); + TF_RET_CHECK(product_of_dimensions > 0); + } + TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size()); + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. std::vector devices(proto.tile_assignment_devices().begin(), @@ -429,18 +446,32 @@ Shape HloSharding::TileShape(const Shape& shape) const { HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - - Shape sub_shape = ShapeUtil::GetSubshape(shape, index); - ShapeTree sub_shape_tree(sub_shape, Replicate()); - sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return ShapeUtil::IsTuple(sub_shape) ? 
Tuple(sub_shape_tree) - : sub_shape_tree.element(ShapeIndex({})); + int64 sharding_index = 0; + const Shape* sub_shape = &shape; + for (int64 idx : index) { + for (int64 i = 0; i < idx; ++i) { + sharding_index += + ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i})); + } + sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); + } + if (ShapeUtil::IsTuple(*sub_shape)) { + auto begin_it = tuple_elements_.begin() + sharding_index; + std::vector sub_shardings( + begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); + return HloSharding::Tuple(*sub_shape, sub_shardings); + } else { + return tuple_elements_[sharding_index]; + } } absl::optional HloSharding::ExtractSingleSharding() const { if (!IsTuple()) { return *this; } + if (tuple_elements_.empty()) { + return absl::nullopt; + } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { return absl::nullopt; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index be51c3f55b59aa65dbb15210b494a5e795f0cd3e..9775505f8608ced3e33abe376f4922cc6a972726 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -23,12 +23,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -66,7 +66,7 @@ class HloSharding { // shardings must match the number of leaf nodes in tuple_shape. For // empty tuples, the shardings array must have one element. 
static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings); + absl::Span shardings); // Creates a new sharding for a tuple type, with a single input sharding // repeated on each leaf. @@ -132,7 +132,7 @@ class HloSharding { // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. // REQUIRES: !IsTuple() - int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; + int64 DeviceForTileIndex(absl::Span index) const; // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index a9b3b66934bc6feb0b114d25b1cc8b4e613ff3be..e3f4a9852ace86c20610362aa6ad3c3d9c78de30 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -24,6 +24,23 @@ namespace xla { namespace { +// AssignmentKind and kUnassignedDevice are used during tuple domain sharding +// propagation in order to distinguish among three cases: +// kUnassigned: no assignment has occurred +// kAssigned: at least an assignment has occurred +// kConflict: no assignment has occurred because of conflicting propagations, +// which occurs when multiple users of an instruction have different +// shardings. +enum class AssignmentKind { kUnassigned, kAssigned, kConflict }; + +// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate +// absence of sharding information for that particular sub-sharding during +// sharding propagation. It is used to be able to express tuple shardings with +// partial information. At the end of the propagation the sharding of +// tuple-shaped instructions using kUnassignedDevice's is cleared. +// TODO(b/112883246): Centralized enum of reserved devices. 
+constexpr int64 kUnassignedDevice = -2; + struct PassThrough { PassThrough(HloInstruction* user, HloInstruction* operand) : user(user), operand(operand) {} @@ -147,108 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, return Status::OK(); } -// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. -// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() -// sharding will be returned. -ShapeTree GetTupleSharding(HloInstruction* tuple) { - if (tuple->has_sharding()) { - return tuple->sharding().GetAsShapeTree(tuple->shape()); +// Return the ShapeTree of the user argument. The user argument +// is assumed to be a user of the instruction argument. +// If user is a tuple instruction, return the tuple subsharding corresponding to +// the operand matching the instruction argument, because that is the +// subsharding corresponding to instruction. +ShapeTree GetShardingTreeFromUser( + const HloInstruction& instruction, const HloInstruction& user) { + if (user.opcode() == HloOpcode::kTuple) { + return user.sharding() + .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) + .GetAsShapeTree(instruction.shape()); + } + return user.sharding().GetAsShapeTree(user.shape()); +} + +// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice) +// then no assignment is made. Therefore kUnassignedDevice is never propagated. +// kConflict is returned if lhs is already assigned and rhs is assigned to a +// different device. +StatusOr AssignLeafSharding(HloSharding* lhs, + const HloSharding& rhs) { + TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple()); + if (rhs.UsesDevice(kUnassignedDevice)) { + return AssignmentKind::kUnassigned; + } + if (lhs->UsesDevice(kUnassignedDevice)) { + *lhs = rhs; + return AssignmentKind::kAssigned; } - return ShapeTree(tuple->shape(), HloSharding::Replicate()); + return lhs->UniqueDevice() != rhs.UniqueDevice() + ? 
AssignmentKind::kConflict + : AssignmentKind::kUnassigned; +} + +// Assigns the whole rhs tree to lhs_tree, starting at lhs_it. +// In case of conflicting assignment AssignmentKind::kConflict is returned. In +// this case lhs_tree is partially assigned, up to the conflicting leaf. It is +// up to the caller to discard the partial assignment in case of conflict. +StatusOr AssignTreeSharding( + ShapeTree* lhs_tree, ShapeTree::iterator lhs_it, + const ShapeTree& rhs_tree) { + AssignmentKind assigned = AssignmentKind::kUnassigned; + auto rhs_it = rhs_tree.begin(); + for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end(); + ++lhs_it, ++rhs_it) { + // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it) + if (rhs_tree.IsLeaf(rhs_it->first)) { + TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first)); + TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned, + AssignLeafSharding(&lhs_it->second, rhs_it->second)); + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we return conflict to the caller. At this point + // partial assignments to lhs_tree may have been made already. It is up + // to the caller to discard the partial assignment in case of conflict. + return AssignmentKind::kConflict; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + } + TF_RET_CHECK(rhs_it == rhs_tree.end()); + return assigned; } -// Retrieves the sharding of operand, asked from a user instruction which is -// within domain. If operand is a kDomain, it means that sharding argument is -// the operand sharding, otherwise the operand's own sharding will be returned. -const HloSharding* GetOperandSharding(const HloInstruction* operand, +StatusOr ApplyShardingFromUsers(HloInstruction* instruction, const DomainMetadata::Domain& domain, - const HloSharding& sharding) { - // Here the user of operand is within the domain instruction set, and since it - // is user of operand, we need to look into the enter_domains set. 
If this is - // not a kDomain within the user domains set, then return the operand - // sharding, if any. - if (operand->opcode() != HloOpcode::kDomain || - domain.enter_domains.count(const_cast(operand)) == 0) { - return operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding& domain_sharding) { + if (instruction->users().empty()) { + // No sharding from users, use domain_sharding, after checking + // compatibility. + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + ShapeUtil::GetLeafCount(instruction->shape()) == + domain_sharding.tuple_elements().size()); + instruction->set_sharding(domain_sharding); + return true; + } + AssignmentKind assigned = AssignmentKind::kUnassigned; + // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple + // subshardings can result in a final sharding assignment containing + // kUnassignedDevice leaves, in case some tuple indexes are not used, or are + // used by users that don't have a sharding. + // Non-tuple shardings are either assigned to a real sharding, or are not + // assigned at all. As such they will never get assigned to kUnassignedDevice. + // In any case, kUnassignedDevice is never propagated, from the implementation + // of AssignLeafSharding. + ShapeTree sharding_tree( + instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(const_cast(user)) > 0) { + // If a user is a domain and it is registered in the domain exits, then + // the instruction sharding is taken directly from the domain, and no + // further users need to be visited. 
+ instruction->set_sharding(domain_sharding); + return true; + } + if (!user->has_sharding()) { + continue; + } + AssignmentKind sub_assigned = AssignmentKind::kUnassigned; + ShapeTree user_sharding_tree = + GetShardingTreeFromUser(*instruction, *user); + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuple-shaped instructions collect individual tuple subshardings + // from the uses, and then combine them into the tuple sharding. + // If the user is a GTE its sharding concerns only the subtree of + // sharding_tree at index user->tuple_index, otherwise the whole + // sharding_tree is affected. + ShapeTree::iterator sharding_tree_begin = + user->opcode() == HloOpcode::kGetTupleElement + ? sharding_tree.find({user->tuple_index()}) + : sharding_tree.begin(); + TF_ASSIGN_OR_RETURN( + sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin, + user_sharding_tree)); + } else { + // Non-tuple shape: assign common users sharding. + TF_RET_CHECK(user_sharding_tree.leaf_count() == 1) + << "Expected non-tuple user sharding"; + TF_ASSIGN_OR_RETURN( + sub_assigned, + AssignTreeSharding(&sharding_tree, sharding_tree.begin(), + user_sharding_tree)); + } + + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we don't assign any sharding. + return false; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + + if (assigned == AssignmentKind::kAssigned) { + if (ShapeUtil::IsTuple(instruction->shape())) { + instruction->set_sharding(HloSharding::Tuple(sharding_tree)); + } else { + TF_RET_CHECK(sharding_tree.leaf_count() == 1); + instruction->set_sharding(sharding_tree.leaf_begin()->second); + } + return true; } - // At this point operand is a kDomain of the currently processed domain, so we - // can refer to sharding as the domain sharding. 
- return &sharding; + return false; } // Tries to propagate the sharding information into the instructions that are -// part of the domain, in a post order manner (operand propagate to user). +// part of the domain, in a reverse post order manner (users propoagate to +// instruction). StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, - const HloSharding& sharding) { + const HloSharding& domain_sharding) { int64 assigned = 0; - for (HloInstruction* instruction : domain.instructions) { + // domain.instructions are ordered in a post-order manner. As we do + // user->operand propagation we process instructions in reverse order. In so + // doing we are guaranteed to process all users before their operands. + for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend(); + ++it) { + HloInstruction* instruction = *it; if (instruction->has_sharding()) { continue; } - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* tuple = instruction->mutable_operand(0); - const HloSharding* tuple_sharding = - GetOperandSharding(tuple, domain, sharding); - if (tuple_sharding != nullptr) { - if (tuple_sharding->IsTuple()) { - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); - } else { - SetSingleSharding(instruction, *tuple_sharding); - } - ++assigned; - } - } else if (instruction->opcode() == HloOpcode::kTuple) { - int64 tuple_assigned = 0; - ShapeTree shape_tree = GetTupleSharding(instruction); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const HloSharding* operand_sharding = - GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr) { - HloSharding operand_subsharding = HloSharding::Replicate(); - if (operand_sharding == &sharding) { - operand_subsharding = - 
sharding.GetSubSharding(instruction->shape(), {i}); - operand_sharding = &operand_subsharding; - } - if (shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; - } - } - } - if (tuple_assigned > 0) { - HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); - VLOG(4) << " " << instruction->name() << " to sharding " - << tuple_sharding; - instruction->set_sharding(tuple_sharding); - ++assigned; - } - } else { - // If all the operand of the given instruction has the same single device - // assignment, assign that device to this instruction as well. - const HloSharding* common_sharding = nullptr; - for (const HloInstruction* operand : instruction->operands()) { - const HloSharding* operand_sharding = - GetOperandSharding(operand, domain, sharding); - if (operand_sharding != nullptr) { - if (common_sharding != nullptr && - *common_sharding != *operand_sharding) { - common_sharding = nullptr; - break; - } - common_sharding = operand_sharding; - } - } - if (common_sharding != nullptr) { - VLOG(4) << " " << instruction->name() << " to sharding " - << *common_sharding; - instruction->set_sharding(*common_sharding); - ++assigned; - } + // Take the sharding from the users. 
+ TF_ASSIGN_OR_RETURN( + bool instruction_assigned, + ApplyShardingFromUsers(instruction, domain, domain_sharding)); + if (instruction_assigned) { + ++assigned; + VLOG(4) << " " << instruction->name() << " to sharding " + << instruction->sharding(); } } return assigned; @@ -266,18 +349,22 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; - for (;;) { - TF_ASSIGN_OR_RETURN(int64 assigned, - ApplyDomainShardingPass(domain, sharding)); - if (assigned == 0) { - break; - } - } + TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status()); + int64 unassigned = 0; for (HloInstruction* instruction : domain.instructions) { if (!instruction->has_sharding()) { LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); ++unassigned; + } else { + // Un-set sharding of tuples whose sub-sgardings are assigned to + // kUnassignedDevice. Indeed in case of doubt it is better to leave the + // entire tuple unassigned, and let the device placer decide for it. + if (instruction->sharding().UsesDevice(kUnassignedDevice)) { + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + << "Only tuples can have kUnassignedDevice sub shardings"; + instruction->clear_sharding(); + } } } // Should we error out if unassigned > 0? @@ -285,7 +372,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, } StatusOr> ExtractOriginalCommonSharding( - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. 
// As such, no kDomain was inserted, and here we are asked to extract the @@ -335,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { : false; } +size_t ShardingMetadata::Hash() const { + if (sharding_ != nullptr) { + return sharding_->Hash(); + } + return static_cast(0x297814aaad196e6dULL); +} + string ShardingMetadata::ToString() const { return sharding_ != nullptr ? sharding_->ToString() : "{}"; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index 7a6b0d9abcbf1f8206654fc66e6dd99f82696556..e3ae82a070643895f2ecac0e64073a88b592f7c1 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata { bool Matches(const DomainMetadata& other) const override; + size_t Hash() const override; + string ToString() const override; const HloSharding* sharding() const { return sharding_.get(); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 2341f8ada0dba4e5a5f39e991498a2ee44303dbd..80634677e78e4a35dcb9bf7de018a88122c3c030 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -29,8 +29,8 @@ limitations under the License. 
namespace xla { namespace { -Array MakeArray(tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice contents) { +Array MakeArray(absl::Span dimensions, + absl::Span contents) { Array a(dimensions); std::copy(contents.begin(), contents.end(), a.begin()); return a; diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index d1cf644f8273e632e2952cca0da749616e9b6233..fa34bddde1a47b520f7f96361d155e4017e44e60 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -22,7 +22,7 @@ namespace xla { // Unify subcomputations of a `HloModule`: if any computations are equal, choose // one arbitrarily to use and delete the others. -class HloSubcomputationUnification : public HloPassInterface { +class HloSubcomputationUnification : public HloModulePass { public: absl::string_view name() const override { return "subcomputation-unification"; diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloTestBase { +class HloTfGraphBuilderTest : public HloVerifiedTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index e0c13261772cf7eb9f71cd02182dc3166ba172ed..59594ab2f0f70a206c73e998dbfa69c2c5c7ba43 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -131,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); + case HloOpcode::kDomain: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. 
return false; @@ -149,7 +150,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace void HloValue::SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions) { + absl::Span positions) { CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; // The positions must be unique and should not contain the defining position @@ -166,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses( positions_.insert(positions_.end(), positions.begin(), positions.end()); // Gather the computation roots at which this value appears. - tensorflow::gtl::FlatSet root_positions; + absl::flat_hash_set root_positions; for (const HloPosition& position : positions_) { if (position.instruction == position.instruction->parent()->root_instruction()) { @@ -222,8 +223,7 @@ string HloValueSet::ToString() const { })); } -bool HloValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { +bool HloValueSet::AssignUnionOf(absl::Span inputs) { HloValueSet union_set; for (const HloValueSet* input : inputs) { for (const HloValue* value : input->values()) { @@ -254,7 +254,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { } bool InstructionValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK_GT(inputs.size(), 0); for (int i = 1; i < inputs.size(); ++i) { DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index a1151f65e07dffdcd52f645f61dcc9b4f26459c0..b6670d409b92e8be42f5cdb40fba8d662ae83958 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -20,13 +20,13 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -108,8 +108,7 @@ class HloValue : public BufferValue { // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not // be included in 'positions' as this is set at construction time. - void SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions); + void SetPositionsAndComputeUses(absl::Span positions); // Returns whether this value is a phi value. bool is_phi() const { return is_phi_; } @@ -186,14 +185,14 @@ class HloValueSet { public: HloValueSet() = default; - explicit HloValueSet(tensorflow::gtl::ArraySlice values) + explicit HloValueSet(absl::Span values) : values_(values.begin(), values.end()) { SortAndUniquifyValues(); } // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf(tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); // Return the vector of HloValues in the set. Values in the vector are unique // and stably sorted by value id. @@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree { // Sets this value set to the union of the given value sets. Returns whether // this value set changed. 
- bool AssignUnionOf( - tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); string ToString() const; }; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index f60c4eab4270e419642ced71d041db0127a9c74d..ba95cef21da404646c3d347d3599209ce0a7f987 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,17 +15,35 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { +Status ShapeVerifier::Preprocess(HloInstruction* hlo) { + if (LayoutUtil::IsSparseArray(hlo->shape())) { + return InternalError("Sparse arrays are not yet fully supported: %s", + hlo->ToString()); + } + return Status::OK(); +} + +static Status CheckOperandCount(const HloInstruction* hlo, int expected) { + if (hlo->operand_count() != expected) { + return InternalError("Expected %d operands for %s instruction: %s", + expected, HloOpcodeString(hlo->opcode()), + hlo->ToString()); + } + return Status::OK(); +} + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -57,12 +75,14 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status 
ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -73,6 +93,7 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -81,16 +102,18 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers(), - convolution->feature_group_count())); + convolution->feature_group_count(), convolution->window(), + convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -116,7 +139,14 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); + return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( + hlo->operand(0)->shape())); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( 
reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -128,10 +158,9 @@ Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( - "Expected operand %lld to be token-shaped, actual shape is " + "Expected operand %d to be token-shaped, actual shape is " "%s:\n%s", - operand_no, StringifyShape(token->shape()).c_str(), - instruction->ToString().c_str()); + operand_no, StringifyShape(token->shape()), instruction->ToString()); } return Status::OK(); } @@ -144,14 +173,14 @@ Status ShapeVerifier::CheckOperandAndParameter( computation->parameter_instruction(parameter_number); if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", - operand->ToString().c_str(), - parameter->ToString().c_str(), - instruction->ToString().c_str()); + operand->ToString(), parameter->ToString(), + instruction->ToString()); } return Status::OK(); } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -162,6 +191,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -171,9 +201,8 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { return InternalError( "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - StringifyShape(outfeed->operand(0)->shape()).c_str(), - StringifyShape(outfeed->outfeed_shape()).c_str(), - outfeed->ToString().c_str()); + 
StringifyShape(outfeed->operand(0)->shape()), + StringifyShape(outfeed->outfeed_shape()), outfeed->ToString()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } @@ -189,24 +218,21 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, } Status ShapeVerifier::HandleRng(HloInstruction* instruction) { - if (instruction->operand_count() != 2) { - return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString().c_str()); - } + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { return InternalError( "Expected scalar types for the two operands of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { return InternalError( "Expected compatible element types for the result and the two operands" " of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } PrimitiveType element_type = shape_0.element_type(); @@ -219,7 +245,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { "Element type not supported." " Expected element to be of floating point type, integral type or" " predicate type for RngUniform: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; @@ -228,48 +254,70 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { return InternalError( "Element type not supported." 
" Expected element to be FloatingPointType for RngNormal: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; default: return InternalError( "Invalid Rng distribution %s", - RandomDistribution_Name(instruction->random_distribution()).c_str()); + RandomDistribution_Name(instruction->random_distribution())); } return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - if (sort->operand_count() == 2 && - !ShapeUtil::SameDimensions(sort->operand(0)->shape(), - sort->operand(1)->shape())) { - return InternalError( - "Expected sort to have to have the same dimensions for the keys and " - "the values. Keys shape is: %s\n, Values shape is: %s", - StringifyShape(sort->operand(0)->shape()).c_str(), - StringifyShape(sort->operand(1)->shape()).c_str()); + if (sort->operand_count() < 1) { + return InternalError("Expected at least 1 operand for %s instruction: %s", + HloOpcodeString(sort->opcode()), sort->ToString()); + } + for (int64 operand = 1; operand < sort->operand_count(); ++operand) { + if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(operand)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys " + "and the values. 
Keys shape is: %s\n, Values shape (operand index " + "%lld) is: %s", + StringifyShape(sort->operand(0)->shape()), operand, + StringifyShape(sort->operand(operand)->shape())); + } } return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); + if (!Cast(constant)->HasLiteral()) { + return InternalError("Constant is required to have a valid literal: %s", + constant->ToString()); + } return CheckShape(constant, constant->literal().shape()); } -Status ShapeVerifier::HandleIota(HloInstruction* iota) { - return ShapeUtil::Rank(iota->shape()) == 1 - ? Status::OK() - : InternalError("Iota only supports arrays of rank 1."); +Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); + auto* iota = Cast(instruction); + const int64 rank = ShapeUtil::Rank(iota->shape()); + if (rank == 0) { + return InternalError("Iota does not support scalars."); + } + int64 iota_dimension = iota->iota_dimension(); + if (iota_dimension >= rank) { + return InternalError( + "The iota dimension cannot go beyond the operation rank."); + } + return Status::OK(); } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), @@ -277,21 +325,28 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsArray(reduce->shape())) { - return InvalidArgument("Variadic reduce is not supported."); + if (reduce->operand_count() % 2 != 0) { + return InternalError( + "Expected an even number of operands for %s instruction: %s", + HloOpcodeString(reduce->opcode()), reduce->ToString()); } - return CheckShape( - reduce, - ShapeInference::InferReduceShape( 
- {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, - reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); + + std::vector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(reduce, ShapeInference::InferReduceShape( + operand_shapes, reduce->dimensions(), + reduce->to_apply()->ComputeProgramShape())); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); @@ -303,14 +358,16 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_dimension < ShapeUtil::Rank(operand_shape); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)) + TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + (broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension))) << broadcast->ToString() << " operand shape " << operand_shape; } return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. 
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == @@ -319,12 +376,14 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } @@ -333,7 +392,7 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { int64 param_no = fused_param->parameter_number(); if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( - "Shape mismatch between parameter number %lld and its operand in " + "Shape mismatch between parameter number %d and its operand in " "%s.", param_no, fusion->ToString().c_str()); } @@ -349,9 +408,30 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + TF_RET_CHECK(custom_call != nullptr); + if (custom_call->layout_constrained()) { + // If the layout is constrained, verify all the respective shapes have + // layouts and that the constrained operand shapes match the shapes of the + // operands. 
+ TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape())); + TF_RET_CHECK(custom_call->operand_count() == + custom_call->operand_shapes_with_layout().size()); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + const Shape& operand_shape_with_layout = + custom_call->operand_shapes_with_layout()[i]; + TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), + operand_shape_with_layout)); + TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -359,6 +439,7 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( dynamic_slice->operand(0)->shape(), dynamic_slice->operand(1)->shape(), @@ -367,6 +448,7 @@ Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); return CheckShape(dynamic_update_slice, ShapeInference::InferDynamicUpdateSliceShape( dynamic_update_slice->operand(0)->shape(), @@ -396,6 +478,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); return CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( @@ -405,6 +488,7 @@ Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return 
CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -415,6 +499,7 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); TF_RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); TF_RETURN_IF_ERROR( @@ -425,7 +510,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - StringifyShape(conditional_shape).c_str()); + StringifyShape(conditional_shape)); } // The shape of kWhile should match the shape of the body computation it // calls. @@ -434,6 +519,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( conditional, 1, conditional->true_computation(), 0)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( @@ -448,12 +534,14 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } Status ShapeVerifier::HandlePad(HloInstruction* pad) { + TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -461,10 +549,12 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + 
TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -472,6 +562,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( recv_done, ShapeUtil::MakeTupleShape( @@ -481,6 +572,7 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -491,6 +583,7 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( batch_norm_inference->operand(0)->shape(), @@ -502,6 +595,7 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -538,6 +632,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kSort: case HloOpcode::kTuple: case HloOpcode::kWhile: break; @@ -556,7 +651,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", - instruction->ToString().c_str()); + 
instruction->ToString()); } return Status::OK(); })); @@ -569,6 +664,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { + TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -577,6 +673,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -646,9 +743,8 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, return InternalError( "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - StringifyShape(inferred_shape).c_str(), - StringifyShape(instruction->shape()).c_str(), - instruction->ToString().c_str()); + StringifyShape(inferred_shape), StringifyShape(instruction->shape()), + instruction->ToString()); } return Status::OK(); } @@ -665,12 +761,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -678,6 +776,7 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( 
instruction->opcode(), instruction->operand(0), @@ -690,8 +789,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } -string ComputationsToString( - tensorflow::gtl::ArraySlice computations) { +string ComputationsToString(absl::Span computations) { return absl::StrJoin(computations, ",", [](string* s, const HloComputation* computation) { s->append(computation->name()); @@ -713,23 +811,23 @@ Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { return InternalError("Computation %s has a null parent pointer", - computation->name().c_str()); + computation->name()); } if (computation->parent() != module) { return InternalError( "Computation %s parent() does not point to parent module", - computation->name().c_str()); + computation->name()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { return InternalError("Instruction %s has a null parent pointer", - instruction->name().c_str()); + instruction->name()); } if (instruction->parent() != computation) { return InternalError( "Instruction %s parent() does not point to parent computation", - instruction->name().c_str()); + instruction->name()); } } } @@ -746,9 +844,8 @@ Status VerifyHloStructure(HloModule* module) { return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", - i, operand->name().c_str(), instruction->name().c_str(), - operand->parent()->name().c_str(), - instruction->parent()->name().c_str()); + i, operand->name(), instruction->name(), + operand->parent()->name(), instruction->parent()->name()); } } } @@ -756,7 +853,186 @@ Status VerifyHloStructure(HloModule* module) { return Status::OK(); } -Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { +namespace { + +// Returns true if the given Shape has a 
TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. +Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape())); + } + } + return Status::OK(); +} + +// Verifies that entry computation layout matches characteristics of +// entry computation. 
+Status CheckEntryComputationLayout(const HloModule& module) { + const HloComputation* computation = module.entry_computation(); + const auto& layout = module.entry_computation_layout(); + const ShapeLayout& result_layout = layout.result_layout(); + + if (LayoutUtil::IsSparseArray(result_layout.shape())) { + return Unimplemented( + "Sparse arrays are not yet fully supported in program result shape: %s", + ShapeUtil::HumanStringWithLayout(result_layout.shape())); + } + + if (!ShapeUtil::Compatible(computation->root_instruction()->shape(), + result_layout.shape())) { + return InternalError( + "Shape of the root instruction of entry computation (%s) should be " + "compatible to one specified in module's entry computation layout (%s)", + ShapeUtil::HumanString(computation->root_instruction()->shape()), + ShapeUtil::HumanString(result_layout.shape())); + } + + if (computation->num_parameters() != layout.parameter_count()) { + return InternalError( + "Number of parameters in entry computation layout (%d) must be same " + "as number of parameters of entry computation computation (%d)", + layout.parameter_count(), computation->num_parameters()); + } + + for (int i = 0; i < computation->num_parameters(); ++i) { + const HloInstruction* parameter = computation->parameter_instruction(i); + if (LayoutUtil::IsSparseArray(layout.parameter_shape(i))) { + return Unimplemented( + "Sparse arrays are not yet fully supported " + "in program parameter shape: %s", + ShapeUtil::HumanStringWithLayout(layout.parameter_shape(i))); + } + if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) { + return InternalError( + "Shape of the entry computation parameter %d is %s should be " + "compatible to the one specified in module's entry computation " + "layout %s", + i, ShapeUtil::HumanString(parameter->shape()), + ShapeUtil::HumanString(layout.parameter_shape(i))); + } + } + + return Status::OK(); +} + +// Checks if the given two instructions share the same channel id. 
+Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return InternalError( + "Expected to have the same channel id, actual channel ids are: %s " + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); + } + return Status::OK(); +} + +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Intsructions must be send/recv instructions or their +// 'done' variant. +Status CheckSameIsHostTransfer(const HloInstruction* instr1, + const HloInstruction* instr2) { + const HloSendRecvInstruction* send_recv1 = + DynCast(instr1); + const HloSendRecvInstruction* send_recv2 = + DynCast(instr2); + TF_RET_CHECK(send_recv1 != nullptr); + TF_RET_CHECK(send_recv2 != nullptr); + if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { + return InternalError( + "Expected instructions to have the same is-host-transfer property: " + "%s, " + "%s ", + instr1->ToString(), instr2->ToString()); + } + return Status::OK(); +} + +// Checks various invariants of send and recv instructions. +Status VerifySendsAndRecvs(const HloModule& module) { + absl::flat_hash_map host_channels; + // Host send/recv instructions must have their own unique channel. + auto check_unique_host_channel = [&](const HloInstruction* instruction) { + const HloSendRecvInstruction* sendrecv = + DynCast(instruction); + if (sendrecv->is_host_transfer()) { + auto it_inserted = + host_channels.insert({sendrecv->channel_id(), sendrecv}); + if (!it_inserted.second) { + return FailedPrecondition( + "Channel %d is used for multiple host send/recv instructions: " + "%s " + "and " + "%s", + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); + } + } + + return Status::OK(); + }; + + // Send/Recv instruction must have a single user: the corresponding + // SendDone/RecvDone. with matching channel. 
+ for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + break; + } + case HloOpcode::kSendDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + break; + default: + break; + } + } + } + return Status::OK(); +} + +// CHECKs various invariants of a fusion instruction. +Status CheckFusionInstruction(HloInstruction* fusion) { // The parent fusion instruction of the fusion computation must be 'fusion'. 
HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { @@ -764,7 +1040,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { "Instruction of fused computation does not match expected " "instruction " "%s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Fused root instruction and fused parameters must all be owned by the @@ -778,7 +1054,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_root == instruction) { if (root_owned) { return InternalError("Root appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } root_owned = true; } @@ -786,7 +1062,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return InternalError("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } parameter_owned[i] = true; } @@ -794,20 +1070,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } if (!root_owned) { return InternalError("Root not found in computation of %s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { return InternalError("Parameter %d not found in computation of %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } // Fused root must have no users. 
if (fused_root->user_count() != 0) { - return InternalError("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", fusion->ToString()); } // All uses of fused instructions must be in the fusion computation, and @@ -817,14 +1092,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (instruction != fused_root) { if (instruction->user_count() == 0) { return InternalError("Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), - fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { return InternalError( "Non-root instruction %s in %s may not have external users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } } } @@ -837,19 +1111,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return InternalError("Unexpected negative parameter number %lld in %s.", - param_no, fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %d in %s.", + param_no, fusion->ToString()); } if (param_no >= fused_parameters.size()) { return InternalError( - "Unexpected parameter number %lld in %s: higher then number of " + "Unexpected parameter number %d in %s: higher then number of " "parameters %lu.", - param_no, fusion->ToString().c_str(), fused_parameters.size()); + param_no, fusion->ToString(), fused_parameters.size()); } if (parameter_numbers[param_no]) { return InternalError( - "Did not expect parameter number %lld more than once in %s.", - param_no, fusion->ToString().c_str()); + "Did not expect parameter number %d more than once in %s.", param_no, + fusion->ToString()); } parameter_numbers[param_no] = true; } @@ 
-857,56 +1131,36 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { return InternalError("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } + TF_RET_CHECK(fusion->called_computations() == + absl::Span( + {fusion->fused_instructions_computation()})) + << "Fusion HLO calls computations other than the " + "fused_instructions_computation: " + << fusion->ToString() << " fusion->fused_instructions_computation(): " + << fusion->fused_instructions_computation()->ToString() + << " fusion->called_computations(): " + << ComputationsToString(fusion->called_computations()); + + for (const auto& fused : fusion->fused_instructions()) { + TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation()) + << "Fused HLO was missing a parent: " << fused->ToString() + << " parent: " << fused->parent() + << " computation: " << fusion->parent(); + } + // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. 
return Status::OK(); } -Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - if (while_cond->num_parameters() != 1) { - return FailedPrecondition( - "While condition must have exactly 1 parameter; had %lld : %s", - while_cond->num_parameters(), while_cond->ToString().c_str()); - } - if (while_body->num_parameters() != 1) { - return FailedPrecondition( - "While body must have exactly 1 parameter; had %lld : %s", - while_body->num_parameters(), while_body->ToString().c_str()); - } - if (instruction->operand_count() != 1) { - return FailedPrecondition( - "While loop must have exactly one operand; had %lld : %s", - instruction->operand_count(), instruction->ToString().c_str()); - } - return Status::OK(); -} - -Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { - if (instruction->true_computation()->num_parameters() != 1) { - return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %lld", - instruction->true_computation()->name().c_str(), - instruction->ToString().c_str(), - instruction->true_computation()->num_parameters()); - } - if (instruction->false_computation()->num_parameters() != 1) { - return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %lld", - instruction->false_computation()->name().c_str(), - instruction->ToString().c_str(), - instruction->false_computation()->num_parameters()); - } - return Status::OK(); -} - -Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { +// Checks that the non-scalar operand shapes are compatible to the output +// shape, i.e., that there are no implicit broadcasts of size-one dimensions. 
+Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); @@ -915,211 +1169,178 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { "Implicit broadcast is not allowed in HLO." "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", - HloOpcodeString(instruction->opcode()).c_str(), - ShapeUtil::HumanString(out_shape).c_str(), - ShapeUtil::HumanString(operand_shape).c_str()); + HloOpcodeString(instruction->opcode()), + ShapeUtil::HumanString(out_shape), + ShapeUtil::HumanString(operand_shape)); } } return Status::OK(); } -namespace { +// Visitor which verifies various fields on the HLO instruction. This class does +// not check result shape as that is checked in the ShapeVerifier. +class InstructionVerifier : public DfsHloVisitorWithDefault { + public: + explicit InstructionVerifier(std::function + instruction_can_change_layout_func) + : instruction_can_change_layout_func_( + instruction_can_change_layout_func) {} -// Returns true if the given Shape has a TOKEN shape as any subshape. -bool ShapeContainsToken(const Shape& shape) { - bool contains_token = false; - ShapeUtil::ForEachSubshape( - shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsToken(subshape)) { - contains_token = true; - } - }); - return contains_token; -} + Status DefaultAction(HloInstruction*) override { return Status::OK(); } -// Verifies that all types entering and exiting the entry computation are -// legal. -Status VerifyEntryAndExitShapes(const HloModule& module) { - // Tokens cannot be passed as entry parameters. - // TODO(b/80000000): Remove this constraint. 
- for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { - HloInstruction* param = - module.entry_computation()->parameter_instruction(i); - if (ShapeContainsToken(param->shape())) { - return InternalError( - "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape()).c_str()); - } + Status HandleFusion(HloInstruction* fusion) override { + return CheckFusionInstruction(fusion); } - return Status::OK(); -} -// Checks if the given two instructions share the same channel id. -Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( - "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); + Status HandleBroadcast(HloInstruction* broadcast) override { + // If you see this failure then someone has confused the difference + // between the HLO broadcast op, and the UserComputation broadcast + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // or ComputationLowerer::Visit() + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(broadcast->operand(0)->shape())) + << "Broadcast HLO (" << broadcast->ToShortString() + << ") has invalid number of dimensions: " + << broadcast->dimensions().size() + << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + return Status::OK(); } - return Status::OK(); -} -// Checks if the given two instructions have the same is_host_transfer -// attribute value. Intsructions must be send/recv instructions or their -// 'done' variant. 
-Status CheckSameIsHostTransfer(const HloInstruction* instr1, - const HloInstruction* instr2) { - const HloSendRecvInstruction* send_recv1 = - DynCast(instr1); - const HloSendRecvInstruction* send_recv2 = - DynCast(instr2); - TF_RET_CHECK(send_recv1 != nullptr); - TF_RET_CHECK(send_recv2 != nullptr); - if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { - return InternalError( - "Expected instructions to have the same is-host-transfer property: " - "%s, " - "%s ", - instr1->ToString().c_str(), instr2->ToString().c_str()); + Status HandleWhile(HloInstruction* xla_while) override { + auto* while_cond = xla_while->while_condition(); + auto* while_body = xla_while->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); + } + if (xla_while->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %d : %s", + xla_while->operand_count(), xla_while->ToString()); + } + return Status::OK(); } - return Status::OK(); -} -// Checks various invariants of send and recv instructions. -Status VerifySendsAndRecvs(const HloModule& module) { - tensorflow::gtl::FlatMap host_channels; - // Host send/recv instructions must have their own unique channel. 
- auto check_unique_host_channel = [&](const HloInstruction* instruction) { - const HloSendRecvInstruction* sendrecv = - DynCast(instruction); - if (sendrecv->is_host_transfer()) { - auto it_inserted = - host_channels.insert({sendrecv->channel_id(), sendrecv}); - if (!it_inserted.second) { - return FailedPrecondition( - "Channel %lld is used for multiple host send/recv instructions: " - "%s " - "and " - "%s", - sendrecv->channel_id(), sendrecv->ToString().c_str(), - it_inserted.first->second->ToString().c_str()); - } + Status HandleConditional(HloInstruction* conditional) override { + if (conditional->true_computation()->num_parameters() != 1) { + return FailedPrecondition( + "True computation %s of %s must have 1 parameter insted of %d", + conditional->true_computation()->name(), conditional->ToString(), + conditional->true_computation()->num_parameters()); + } + if (conditional->false_computation()->num_parameters() != 1) { + return FailedPrecondition( + "False computation %s of %s must have 1 parameter insted of %d", + conditional->false_computation()->name(), conditional->ToString(), + conditional->false_computation()->num_parameters()); } + return Status::OK(); + } + Status HandleElementwiseUnary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + + Status HandleElementwiseBinary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + + Status HandleGetTupleElement(HloInstruction* gte) override { + TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); return Status::OK(); - }; + } - // Send/Recv instruction must have a single user: the corresponding - // SendDone/RecvDone. with matching channel. 
- for (const HloComputation* computation : module.computations()) { - for (const HloInstruction* instruction : computation->instructions()) { - switch (instruction->opcode()) { - case HloOpcode::kSend: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* send_done = instruction->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); - break; - } - case HloOpcode::kRecv: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* recv_done = instruction->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); - break; + Status HandleTranspose(HloInstruction* transpose) override { + const Shape& shape = transpose->shape(); + const HloInstruction* operand = transpose->operand(0); + TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size()); + TF_RET_CHECK(shape.dimensions().size() == + transpose->operand(0)->shape().dimensions().size()); + TF_RET_CHECK(std::equal( + operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(transpose->dimensions(), shape.dimensions()).begin())) + << "shape: " << shape << ", operand->shape(): " << shape + << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ") + << "}"; + return Status::OK(); + } + + Status Preprocess(HloInstruction* instruction) override { + auto previous = instructions_by_name_.find(instruction->name()); + TF_RET_CHECK(previous == instructions_by_name_.end()) + << "HLO has name that is not unique within module:\n" + << instruction->ToString() + << " in computation: " << 
instruction->parent()->name() + << "\nPrevious HLO with same name:\n" + << previous->second->ToString() + << " in computation: " << previous->second->parent()->name(); + instructions_by_name_[instruction->name()] = instruction; + return Status::OK(); + } + + Status Postprocess(HloInstruction* instruction) override { + if (instruction_can_change_layout_func_ && + LayoutUtil::IsDenseArray(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + const Shape& result_shape = instruction->shape(); + const Layout& result_layout = result_shape.layout(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsDenseArray(operand_shape) && + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + const Layout& operand_layout = operand_shape.layout(); + TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) + << "Instruction shouldn't change layouts " + << instruction->ToString() << " From " + << ShapeUtil::HumanString(result_shape) << " To " + << ShapeUtil::HumanString(operand_shape); } - case HloOpcode::kSendDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); - break; - case HloOpcode::kRecvDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); - break; - default: - break; } } + + return Status::OK(); } - return Status::OK(); -} + + private: + absl::flat_hash_map instructions_by_name_; + // Determines whether an instruction can change layouts. 
+ std::function + instruction_can_change_layout_func_; +}; } // namespace StatusOr HloVerifier::Run(HloModule* module) { + TF_RET_CHECK(!module->name().empty()); TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); - tensorflow::gtl::FlatMap instructions; - for (auto* computation : module->computations()) { - for (const auto& instruction : computation->instructions()) { - TF_RET_CHECK(instruction->parent() == computation); - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); - TF_RET_CHECK( - ContainersEqual(instruction->called_computations(), - {instruction->fused_instructions_computation()})) - << "Fusion HLO calls computations other than the " - "fused_instructions_computation: " - << instruction->ToString() - << " instruction->fused_instructions_computation(): " - << instruction->fused_instructions_computation()->ToString() - << " instruction->called_computations(): " - << ComputationsToString(instruction->called_computations()); - - for (const auto& fused : instruction->fused_instructions()) { - TF_RET_CHECK(fused->parent() == - instruction->fused_instructions_computation()) - << "Fused HLO was missing a parent: " << fused->ToString() - << " parent: " << fused->parent() - << " computation: " << computation; - } - } else if (instruction->opcode() == HloOpcode::kBroadcast) { - // If you see this failure then someone has confused the difference - // between the HLO broadcast op, and the UserComputation broadcast - // op. 
See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I - // or ComputationLowerer::Visit() - TF_RET_CHECK(instruction->dimensions().size() == - ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO (" << instruction->ToShortString() - << ") has invalid number of dimensions: " - << instruction->dimensions().size() - << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); - } else if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction)); - } else if (instruction->opcode() != - HloOpcode::kRng /* Rng operands are always scalar. */ - && instruction->IsElementwise()) { - TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); - } - - auto previous = instructions.find(instruction->name()); - TF_RET_CHECK(previous == instructions.end()) - << "HLO has name that is not unique within module:\n" - << instruction->ToString() - << " in computation: " << computation->name() - << "\nPrevious HLO with same name:\n" - << previous->second->ToString() - << " in computation: " << previous->second->parent()->name(); - instructions[instruction->name()] = instruction; - } - std::unique_ptr shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); + + InstructionVerifier instruction_verifier( + instruction_can_change_layout_func_); + TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } + TF_RETURN_IF_ERROR(CheckEntryComputationLayout(*module)); TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + // If the module has a schedule, it must be valid. 
+ if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Verify()); + } + + TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index b6093d667c3b99873ccd03b8048abded2ce30457..e1f3402465746b0478d7bb7e4ee2b66e3f839eb2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -32,6 +32,8 @@ class ShapeVerifier : public DfsHloVisitor { : layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} + Status Preprocess(HloInstruction* hlo) override; + Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; @@ -47,6 +49,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -150,15 +153,21 @@ class ShapeVerifier : public DfsHloVisitor { // HLO pass that verifies invariants of HLO instructions for each computation in // the module. 
-class HloVerifier : public HloPassInterface { +class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function()>; - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}) : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { return absl::make_unique(layout_sensitive, allow_mixed_precision); - }) {} + }), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { + CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); + } // Uses custom shape verification. explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) @@ -171,22 +180,15 @@ class HloVerifier : public HloPassInterface { StatusOr Run(HloModule* module) override; private: - // CHECKs various invariants of a fusion instruction. - Status CheckFusionInstruction(HloInstruction* fusion) const; - - Status CheckWhileInstruction(HloInstruction* instruction); - - Status CheckConditionalInstruction(HloInstruction* instruction); - - // Checks that the non-scalar operand shapes are compatible to the output - // shape, i.e., that there are no implicit broadcasts of size-one dimensions. - Status CheckElementwiseInstruction(HloInstruction* instruction); - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This is a factory function because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; + + // Determines whether an instruction can change layouts. 
+ std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 70b741353d043bbe6bcc6d4bf55e9cf9d0d8d3c3..afe01e5487c3225815e01343d86e9fe894c2cde8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -34,6 +35,8 @@ namespace { using ::testing::HasSubstr; +// This class cannot be converted to use HloVerifiedTestBase. It explicitly +// uses HloTestBase to create and test malformed HLOs. class HloVerifierTest : public HloTestBase { public: HloVerifierTest() @@ -48,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; +class HloVerifierTestLayoutSensitive : public HloTestBase { + public: + HloVerifierTestLayoutSensitive() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -277,5 +288,142 @@ TEST_F(HloVerifierTest, RngElementTypeNotSupported) { EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported")); } +TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. 
That's probably a feature. :) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. :) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +// Simple module containing a convolution as the root. 
+static const char* const kConvHloString = R"( +HloModule module +ENTRY entry_computation { + param0 = f16[128,128,56,56] parameter(0) + param1 = f16[3,3,128,128] parameter(1) + zero_f16 = f16[] constant(0) + ROOT conv = f16[128,128,28,28] convolution(param0, param1), + window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01 +})"; + +TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_window_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive window dilation factor")); +} + +TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_base_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive base area dilation factor")); +} + +static const char* const kAddWithLayoutChangeHlo = R"( + HloModule AddWithLayoutChange + ENTRY AddWithLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[3,4]{0,1} parameter(1) + ROOT add0 = f32[3,4]{1,0} add(par0,par1) + } + )"; + +TEST_F(HloVerifierTest, AddWithLayoutChange) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + 
+TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { + const char* const kSliceWithLayoutChangeHlo = R"( + HloModule SliceWithLayoutChange + ENTRY SliceWithLayoutChange { + par0 = f32[4,5]{0,1} parameter(0) + par1 = s32[2] parameter(1) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + dynamic_slice_sizes={3,4} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kSliceWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { + const char* const kConcatWithLayoutChangeHlo = R"( + HloModule ConcatWithLayoutChange + ENTRY ConcatWithLayoutChange { + par0 = f32[3,5]{0,1} parameter(0) + par1 = f32[3,3]{1,0} parameter(1) + ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1), + dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kConcatWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 581b3ce1e062dd0e15823bbbdc2fce808ee4bcfd..e76b93107c923b41666f6b0a388dda143a8cb50a 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -15,26 +15,26 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { using absl::StrAppend; +using absl::StrAppendFormat; using absl::StrCat; -using tensorflow::strings::Appendf; +using absl::StrFormat; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; -using tensorflow::strings::Printf; string HumanReadableProfileBuilder::ToString() const { string s; - Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n", - computation_name_.c_str(), - HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); + StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n", + computation_name_, + HumanReadableElapsedTime(CyclesToSeconds(total_cycles_))); int64 cumulative_cycles = 0; auto print_op = [&](const OpInfo& op, bool is_total = false) { @@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const { if (op.bytes_accessed > op.cycles) { bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = Printf("%.3fB/cycle", bpc); + bytes_per_cycle = StrFormat("%.3fB/cycle", bpc); } } @@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const { // columns in the output. 
cycles_percent_str = "100.% 100Σ"; } else { - cycles_percent_str = - Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); + cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent, + cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf( + StrAppendFormat( &s, - "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " "%16s :: %s\n", - op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles), op.optimal_seconds < 0 ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), + op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + bytes_per_sec, bytes_per_cycle, op.name); }; float optimal_seconds_sum = 0.0; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index b99624460e3f93fd08166358ac9f454e9a145075..925111fa1f1e48650b0089f402d92e431043eabe 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -32,7 +32,7 @@ class HumanReadableProfileBuilder { explicit HumanReadableProfileBuilder(absl::string_view computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(std::string(computation_name)), + : computation_name_(computation_name), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 
1e-9); @@ -47,10 +47,9 @@ class HumanReadableProfileBuilder { absl::string_view category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back({std::string(op_name), std::string(short_name), - std::string(category), cycles, flop_count, - transcendental_count, bytes_accessed, - optimal_seconds}); + op_infos_.push_back({string(op_name), string(short_name), string(category), + cycles, flop_count, transcendental_count, + bytes_accessed, optimal_seconds}); } // Gets the human-readable profile. diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index 85bb4a8b2450a48d461f1d84e0609a38a6818d9c..9c48b7db613b049536c76237b4cfebbbc47448f3 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -25,7 +25,7 @@ namespace xla { // Pass which replaces all implicit broadcasts with their equivalent sequence of // explicit broadcast and reshape instructions. 
-class ImplicitBroadcastRemover : public HloPassInterface { +class ImplicitBroadcastRemover : public HloModulePass { public: ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index df88587492e256b5a4176971b2f443fda8f43421..f85d31d5225b8012b68f851b2bfec219d736ba0d 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -26,11 +26,6 @@ namespace xla { namespace { class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { - public: - ImplicitBroadcastRemoverTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - protected: ImplicitBroadcastRemover remover_; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 43ef30d1eb645b5d12c1776f8fef28d00452349c..1ebb3319779c00fd4afe90606bf336e16349429d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -23,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -35,7 +36,6 @@ using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; using absl::StrJoin; -using tensorflow::gtl::ArraySlice; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { @@ -96,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( absl::InlinedVector stack; enum DfsState { kDiscovered, kVisited }; - gtl::FlatMap dfs_state_map; + absl::flat_hash_map dfs_state_map; stack.push_back(root); InsertOrDie(&dfs_state_map, root, kDiscovered); @@ -166,6 +166,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + instr->precision_config(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else { @@ -186,7 +187,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape) { + absl::Span output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). // `source` is the inner Gather(A, X). 
@@ -252,8 +253,7 @@ StatusOr IndexedArrayAnalysis::FoldGatherOfGather( StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes, Array* source, - Array* indices) { + absl::Span slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; @@ -314,7 +314,7 @@ namespace { // Returns an index into `values` such that the product of the range // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. -int64 FindSuffixWithProduct(ArraySlice values, int64 product) { +int64 FindSuffixWithProduct(absl::Span values, int64 product) { DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; @@ -343,7 +343,8 @@ struct ReshapePassthroughDimPair { // The returned vector of pairs is sorted in both the result_dim and the // operand_dim components. std::vector ComputeReshapePassthroughDimPairs( - ArraySlice operand_shape, ArraySlice result_shape) { + absl::Span operand_shape, + absl::Span result_shape) { // A reshape can be seen as an index mapping from output index to input index: // // (i_0, ..., i_n) = f(o_0, ..., o_m) @@ -420,7 +421,7 @@ std::vector ComputeReshapePassthroughDimPairs( // Return true if `dim` is stated as an passthrough operand dim in // `passthrough_dims`. 
bool IsReshapePassthroughOperandDim( - ArraySlice passthrough_dims, int64 dim) { + absl::Span passthrough_dims, int64 dim) { return absl::c_any_of(passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { return passthrough_dim_pair.operand_dim == dim; @@ -430,7 +431,8 @@ bool IsReshapePassthroughOperandDim( // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. int64 MapPassthroughOperandDimToResultDim( - ArraySlice passthrough_dims, int64 operand_dim) { + absl::Span passthrough_dims, + int64 operand_dim) { auto it = absl::c_find_if( passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { return passthrough_dim_pair.operand_dim == operand_dim; @@ -439,9 +441,9 @@ int64 MapPassthroughOperandDimToResultDim( return it->result_dim; } -int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, - ArraySlice result_shape, - int64 source_passthrough_dim) { +int64 FindSourcePositionForPassthroughResultDim( + absl::Span operand_shape, absl::Span result_shape, + int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; @@ -499,7 +501,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { if (shape.dimensions(i) == 1) { degenerate_dims_seen++; - } else if (ArrayContains(operand->output_dims(), i)) { + } else if (absl::c_linear_search(operand->output_dims(), i)) { new_output_dims.push_back(i - degenerate_dims_seen); } } @@ -519,8 +521,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( } StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims) { + ScalarIndexedArray* operand, absl::Span degenerate_dims) { if (degenerate_dims.empty()) { return 
operand; } @@ -873,7 +874,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, return nullptr; } - ArraySlice broadcast_dims = broadcast_instr->dimensions(); + absl::Span broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; @@ -896,7 +897,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // The scalar-indexed node "removes" the source dim and "inserts" the output // dims. We do the opposite here to undo the scalar-indexed operation. - ArraySlice output_dims = scalar_indexed_const->output_dims(); + absl::Span output_dims = scalar_indexed_const->output_dims(); for (int64 i = output_dims.size() - 1; i >= 0; --i) { CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); EraseAt(&simulated_index, output_dims[i]); @@ -918,7 +919,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // inner_broadcast_result is the Broadcast'(Const0) bit in // BinaryOp(Broadcast'(Const0), Const1) TF_ASSIGN_OR_RETURN( - std::unique_ptr inner_broadcast_result, + Literal inner_broadcast_result, broadcast_const_operand->literal().Broadcast( scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); @@ -928,12 +929,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + opcode, scalar_indexed_const->literal(), inner_broadcast_result))); } else { TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + opcode, inner_broadcast_result, scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct(literal_for_new_source); @@ -973,12 +974,12 @@ 
namespace { // Returns the non-contracting non-batch dimension (as per `contracting_dims` // and `batch_dims`) if there is exactly one, otherwise returns nullopt. absl::optional GetOnlyNonContractingNonBatchDim( - int64 rank, ArraySlice contracting_dims, - ArraySlice batch_dims) { + int64 rank, absl::Span contracting_dims, + absl::Span batch_dims) { absl::optional result; for (int64 dim = 0; dim < rank; dim++) { - if (!ArrayContains(contracting_dims, dim) && - !ArrayContains(batch_dims, dim)) { + if (!absl::c_linear_search(contracting_dims, dim) && + !absl::c_linear_search(batch_dims, dim)) { if (result.has_value()) { return absl::nullopt; } @@ -998,7 +999,8 @@ absl::optional GetOnlyNonContractingNonBatchDim( // of whatever operand `indexed_array` is to the dot (LHS or RHS). bool CanFoldDotIntoIndexedArray( absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, - ArraySlice contracting_dims, ArraySlice batch_dims) { + absl::Span contracting_dims, + absl::Span batch_dims) { absl::optional non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), contracting_dims, batch_dims); @@ -1030,7 +1032,8 @@ bool CanFoldDotIntoIndexedArray( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1045,9 +1048,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? 
(lhs_rank - 2) : (lhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, lhs->literal(), *rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, lhs->literal(), *rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting LHS // dimension "went". @@ -1063,7 +1067,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1079,9 +1084,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( new_dim_numbers.set_rhs_contracting_dimensions( 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, *lhs->literal(), rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, *lhs->literal(), rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting RHS // dimension "went". 
@@ -1095,8 +1101,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( } StatusOr IndexedArrayAnalysis::ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, - Array* rhs) { + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant @@ -1119,6 +1125,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( dynamic_cast(lhs)) { if (auto* rhs_constant = dynamic_cast(rhs)) { return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + precision_config, lhs_indexed_array, rhs_constant); } } @@ -1126,7 +1133,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( if (auto* rhs_indexed_array = dynamic_cast(rhs)) { if (auto* lhs_constant = dynamic_cast(lhs)) { - return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, + precision_config, lhs_constant, rhs_indexed_array); } } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 3fa7d749e1984cc5d7249499e304593b5413cfe2..e5aa67fd850de647652d66017223e19fb92cc068 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -188,9 +188,7 @@ class IndexedArrayAnalysis { // `output_dims` are the dimensions in the output array that are being used // to compute an index into the `indices` array. 
See the class // documentation and the overview for more details. - tensorflow::gtl::ArraySlice output_dims() const { - return output_dims_; - } + absl::Span output_dims() const { return output_dims_; } private: explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, @@ -265,19 +263,21 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes, Array* source, - Array* indices); + absl::Span slice_sizes, Array* source, Array* indices); StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs); StatusOr ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs); StatusOr ComputeArrayForDot(const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another @@ -303,7 +303,7 @@ class IndexedArrayAnalysis { // G1 = [Arr[i] for i in I2] StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape); + absl::Span output_dims, Shape shape); // Reshapes a scalar-indexed node to remove the degenerate dimensions in its // output. The result is always a scalar-indexed node. @@ -313,8 +313,7 @@ class IndexedArrayAnalysis { // Reshapes a scalar-indexed node such that the result has the degenerate // dimensions `degenerate_dims`. The result is always a scalar-indexed node. 
StatusOr ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims); + ScalarIndexedArray* operand, absl::Span degenerate_dims); StatusOr FoldReshapeOfGather( const Shape& shape, ScalarIndexedConstantArray* operand); @@ -348,28 +347,26 @@ class IndexedArrayAnalysis { } } - Literal* TakeOwnership(std::unique_ptr literal) { + Literal* TakeOwnership(Literal literal) { owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } - StatusOr TakeOwnership( - StatusOr> literal_or_error) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - std::move(literal_or_error)); + StatusOr TakeOwnership(StatusOr literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } std::vector> owned_tensors_; - std::vector> owned_literals_; - tensorflow::gtl::FlatMap cache_; + std::vector owned_literals_; + absl::flat_hash_map cache_; }; // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to // unconditionally add to the regular HLO pass pipeline. -class IndexedArrayAnalysisPrinterPass : public HloPassInterface { +class IndexedArrayAnalysisPrinterPass : public HloModulePass { public: absl::string_view name() const override; StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index c34c32f7d3361efbfca1fdfe5c286a4c03b5dc60..2d03aebc1aca4c55cca588072233b7a18e70a306 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -22,11 +22,6 @@ limitations under the License. 
namespace xla { namespace { class IndexedArrayAnalysisTest : public HloVerifiedTestBase { - public: - IndexedArrayAnalysisTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index be59ce82816c1c30e079449599406705a55400c0..69a4c160ee5c4539272c3085338dc6de1b9347ff 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,10 +22,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -122,6 +124,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -171,7 +174,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { }); return std::count_if(hlo->operands().begin(), hlo->operands().end(), [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { return false; } if (operand->opcode() == HloOpcode::kConstant && @@ -185,39 +189,49 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( 
HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_duplicate) { + const HloInstructionSet& do_not_fuse, + absl::flat_hash_map, bool>* + result_cache) { if (consumer == producer) { return true; } - if (!consumer->IsFusable()) { + if (!consumer->IsFusible()) { return false; } + auto cache_it = result_cache->find(std::make_pair(producer, consumer)); + if (cache_it != result_cache->end()) { + return cache_it->second; + } + bool result = true; for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter - // whether it's fusable. + // whether it's fusible. if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_duplicate.count(consumer_operand) > 0 || - !ShouldFuse(consumer, i)) { - return false; + if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + result = false; + break; } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. + // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. - if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { - return false; + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse, + result_cache)) { + result = false; + break; } } - return true; + result_cache->emplace(std::make_pair(producer, consumer), result); + return result; } InstructionFusion::HloInstructionSet -InstructionFusion::ComputeGloballyUnfusable( - tensorflow::gtl::ArraySlice post_order) { +InstructionFusion::ComputeGloballyUnfusible( + absl::Span post_order) { // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. 
@@ -228,6 +242,8 @@ InstructionFusion::ComputeGloballyUnfusable( // fusing operations that require duplication later depending on // is_expensive_(). HloInstructionSet do_not_duplicate; + absl::flat_hash_map, bool> + can_fuse_on_all_paths_result_cache; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { if (do_not_duplicate.count(producer) > 0) { @@ -270,20 +286,21 @@ InstructionFusion::ComputeGloballyUnfusable( // all of its consumers on all paths. // // That means, that for: - // A --> B (fusable) - // \-> C (non-fusable) + // A --> B (fusible) + // \-> C (non-fusible) // A will be not allowed to be fused into B, as it cannot be fused into C. // // Similarly, for: // A -------------> B // \-> C -> D -/ // If: - // - A is fusable into B and C, and D is fusable into B - // - C is *not* fusable into D + // - A is fusible into B and C, and D is fusible into B + // - C is *not* fusible into D // A will be not allowed to be fused into B, as it cannot be fused via // all paths. - if (producer->IsFusable() && - CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { + if (producer->IsFusible() && + CanFuseOnAllPaths(producer, consumer, do_not_duplicate, + &can_fuse_on_all_paths_result_cache)) { continue; } do_not_duplicate.insert(producer); @@ -293,6 +310,138 @@ InstructionFusion::ComputeGloballyUnfusable( return do_not_duplicate; } +namespace { + +// A FusionQueue that uses reverse post order. +// +// We want to be able to remove arbitrary instructions from the post order and +// also compare positions of instructions in the post order. To make this +// possible, create vector of instructions in post order and create a map from +// HloInstruction* to the instruction's index in the vector. An instruction is +// "removed" from the vector by setting it's element to nullptr. 
+class ReversePostOrderFusionQueue : public FusionQueue { + public: + explicit ReversePostOrderFusionQueue(HloComputation* computation) { + post_order_ = computation->MakeInstructionPostOrder(); + + for (size_t i = 0; i < post_order_.size(); ++i) { + InsertOrDie(&post_order_index_, post_order_[i], i); + } + } + + std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() override { + // Instructions are "removed" from the post order by nulling out the element + // in the vector, so if the pointer is null, continue to the next + // instruction in the sort. + while (!post_order_.empty() && post_order_.back() == nullptr) { + post_order_.pop_back(); + } + if (post_order_.empty()) { + return std::pair>{nullptr, {}}; + } + // We want to iterate in reverse post order, so remove from the back of the + // vector. + HloInstruction* instruction = post_order_.back(); + post_order_.pop_back(); + + CHECK(instruction != nullptr); + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index_.erase(instruction); + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid creating duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... 
+ // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the order + // they appear int the queue. In the example, this ensures that B will be + // considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the queue. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) { + continue; + } + sorted_operand_numbers.push_back(i); + } + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return ( + FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + FindOrDie(post_order_index_, instruction->mutable_operand(j))); + }); + return std::make_pair(instruction, sorted_operand_numbers); + } + + void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) override { + // Fusing an instruction into a fusion instruction can change the operand + // set of the fusion instruction. For simplicity just re-enqueue the + // instruction and reconsider it for further fusion in the next iteration. 
+ InsertOrDie(&post_order_index_, fusion, post_order_.size()); + post_order_.push_back(fusion); + } + + void RemoveInstruction(HloInstruction* instruction) override { + post_order_[FindOrDie(post_order_index_, instruction)] = nullptr; + post_order_index_.erase(instruction); + } + + private: + std::vector post_order_; + absl::flat_hash_map post_order_index_; +}; + +} // namespace + +std::unique_ptr InstructionFusion::GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer) { + return absl::make_unique(computation); +} + StatusOr InstructionFusion::Run(HloModule* module) { VLOG(2) << "Before instruction fusion:"; XLA_VLOG_LINES(2, module->ToString()); @@ -304,116 +453,36 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = computation_->ComputeReachability(); - // We want to be able to remove arbitrary instructions from the post order - // and also compare positions of instructions in the post order. To make - // this possible, create vector of instructions in post order and create a - // map from HloInstruction* to the instruction's index in the vector. An - // instruction is "removed" from the vector by setting it's element to - // nullptr. - std::vector post_order = - computation_->MakeInstructionPostOrder(); - - tensorflow::gtl::FlatMap post_order_index; - for (size_t i = 0; i < post_order.size(); ++i) { - InsertOrDie(&post_order_index, post_order[i], i); - } - - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + auto fusion_queue = + GetFusionQueue(computation_, [&](HloInstruction* producer) { + return do_not_duplicate.count(producer) > 0; + }); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. 
When we fuse an edge, we create a copy of the producer inside the // fusion instruction. - while (!post_order.empty()) { - // We want to iterate in reverse post order, so remove from the back of - // the vector. - HloInstruction* instruction = post_order.back(); - post_order.pop_back(); - - // Instructions are "removed" from the post order by nulling out the - // element in the vector, so if the pointer is null, continue to the next - // instruction in the sort. + while (true) { + auto next_entry = + fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); + auto instruction = next_entry.first; if (instruction == nullptr) { - continue; + break; } - // Remove instruction from the index map to ensure the vector and map stay - // consistent. - post_order_index.erase(instruction); - - if (!instruction->IsFusable() && + if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } - // Consider each operand of this instruction for fusion into this - // instruction. We want to consider the operands in a particular order to - // avoid creating duplicate instruction clones in the fusion instruction. - // For example, consider the following expression: - // - // A = ... - // B = op(A) - // C = op(A, B) - // - // If we are considering the operands of C for fusion into C. We might - // fuse A or B first. If we fuse A first, we get: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // C' = op(A', B) } - // - // Where A' and C' are clones of A and C, respectively. Now only B is an - // operand of the fusion instruction C_fusion, so then we fuse B: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // B' = op(A) - // C' = op(A', B') } - // - // Now A is an operand of C_fusion again, so we then fuse A (again!): - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // A" = .. 
- // B' = op(A") - // C' = op(A', B') } - // - // We prevent this duplication by considering the operands in the reverse - // order they appear in the instruction post order. In the example, this - // ensures that B will be considered before A. - // - // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers; - sorted_operand_numbers.reserve(instruction->operands().size()); - for (int i = 0; i < instruction->operands().size(); ++i) { - // This will happen if we have two possible instructions to fuse the - // same operand into; once the operand is fused into one instruction, - // the other instruction will get a new get-tuple-element as its - // operand, which is not in the post-order index. - // TODO(tjoerg): Look into fusing past these multi-output fuse points. - if (post_order_index.find(instruction->mutable_operand(i)) == - post_order_index.end()) { - continue; - } - sorted_operand_numbers.push_back(i); - } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { - // Instructions with higher indices in the post order come - // first. - return ( - FindOrDie(post_order_index, instruction->mutable_operand(i)) > - FindOrDie(post_order_index, instruction->mutable_operand(j))); - }); + std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (!operand->IsFusable()) { + if (!operand->IsFusible()) { continue; } @@ -423,32 +492,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { // TODO(tjoerg): Consider making multi-output fusion the default. 
if (ShouldFuse(instruction, i) && do_not_duplicate.count(operand) == 0) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = FuseIntoMultiOutput(operand, instruction); } else { continue; } - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); + fusion_queue->OnFusingInstruction(fusion_instruction, operand, + instruction); changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting its - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - + do_not_duplicate.erase(operand); + // Operand is now dead. Remove from queue. + fusion_queue->RemoveInstruction(operand); // Remove from computation. TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + + if (fusion_instruction != instruction) { + do_not_duplicate.erase(instruction); + } break; } } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 8489c3d9ad21e8fdf26f6323e476ae232c8fcf84..f14c6675208c72112aea0179c238b58709d625b5 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -1,3 +1,4 @@ +#include "absl/container/flat_hash_map.h" /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -29,7 +31,7 @@ namespace xla { // with the intent that the loops which compute their values will be fused in // code generation. Derived classes define ShouldFuse method to select which // instructions to fuse. -class InstructionFusion : public HloPassInterface { +class InstructionFusion : public HloModulePass { public: explicit InstructionFusion( std::function is_expensive, @@ -48,6 +50,13 @@ class InstructionFusion : public HloPassInterface { static bool IsExpensive(const HloInstruction& instruction); protected: + // Returns a FusionQueue that implements custom order of instructions being + // fused. The default implementation processes consumers in reverse post + // order. + virtual std::unique_ptr GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer); + // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. // Derived classes should define this method to specify which instructions @@ -117,13 +126,20 @@ class InstructionFusion : public HloPassInterface { // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. 
- bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_fuse); + // + // A map from <producer, consumer> to a bool is required as the result cache + // to store and query the results of calls to this function, in order to avoid + // repeated computations. + bool CanFuseOnAllPaths( + HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_fuse, + absl::flat_hash_map, bool>* + result_cache); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - HloInstructionSet ComputeGloballyUnfusable( - tensorflow::gtl::ArraySlice post_order); + HloInstructionSet ComputeGloballyUnfusible( + absl::Span post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9e7a15f0330d3f06779c850a4b575f84fe0b9505..da1ad90959dc0ab1a840b3390281ce9d4999651e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloComputation::Builder builder(TestName()); auto shape = ShapeUtil::MakeShape(F32, {16, 16}); auto param0 = @@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we do not duplicate the add, as we cannot fuse through the rng. 
// // p0 -> add -------------------------> sub @@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); // A variant of the above that allows the algorithm to put add2 into the set - // of unfusable ops to short-circuit the decision whether add1 should be fused + // of unfusible ops to short-circuit the decision whether add1 should be fused // into sub2. // // /---------------\ diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 581f8d2e92b9d7c4350360282cbd9e69824841ca..1484e14df10d94841c5a2e849761779f5800392d 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -45,8 +45,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", @@ -89,6 +89,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -114,5 +115,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index bb69cb9c47ff2c7de8d13832c4b8e6216c62da73..26643667c8674c85e5d03da4c5a2d63833e1d27f 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ 
b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout); return pipeline.Run(hlo_module).status(); } @@ -56,6 +57,12 @@ StatusOr> InterpreterCompiler::RunHloPasses( return std::move(hlo_module); } +Status InterpreterCompiler::RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented("Module group compilation not supported on Interpreter"); +} + StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* /*device_allocator*/) { @@ -75,17 +82,26 @@ StatusOr> InterpreterCompiler::RunBackend( return std::move(executable); } +StatusOr>> +InterpreterCompiler::RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Module group compilation is not supported on Interpreter."); +} + StatusOr>> InterpreterCompiler::Compile( - std::vector> /*hlo_modules*/, + 
std::unique_ptr /*module_group*/, std::vector> /*stream_execs*/, DeviceMemoryAllocator* /*device_allocator*/) { - return tensorflow::errors::Unimplemented( - "Compilation of multiple HLO modules is not supported on Interpreter."); + return Unimplemented( + "Module group compilation is not supported on Interpreter."); } StatusOr>> InterpreterCompiler::CompileAheadOfTime( - std::vector> hlo_modules, + std::unique_ptr module_group, const AotCompilationOptions& aot_options) { return tensorflow::errors::InvalidArgument( "AOT compilation not supported on Interpreter"); diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index e90ae3e818522e6e4fd9d9f5acb846800bc899ca..d8cb32c0beb279ae6484b1b8f5f99085c2d67c67 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -46,18 +46,25 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; + Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; + StatusOr>> RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( - std::vector> hlo_modules, + std::unique_ptr module_group, std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> hlo_modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; diff --git 
a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 2259dc1083e6d1ca64cc7d7b8d9c566a27183ac7..a06d6113e84630df14ff68280c248cccb9afaf06 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {} StatusOr InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); se::StreamExecutor* executor = stream->parent(); @@ -73,30 +73,29 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. - std::vector> arg_literals; + std::vector arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - std::unique_ptr result_literal; + Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate>( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( + *computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. 
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, transfer_manager->AllocateScopedShapedBuffer( - result_literal->shape(), run_options->allocator(), + result_literal.shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), *result_literal, result)); + run_options->stream(), result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -111,7 +110,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 91d8148d26dc8eddbafdaf4870d9efbb73a12816..3b1ebce0c75457d65e6834c809fe488a9c4a159a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index db6b910b32f8ec234c4cf1c331a1aa3bb2f9389f..fbb99457847dca69a1901006d5d8ff713882f918 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_description.h" @@ -47,7 +47,7 @@ limitations under the License. 
namespace stream_executor { namespace interpreter { -using Args = tensorflow::gtl::ArraySlice; +using Args = absl::Span; class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index e57a9b3672391e11b130b1c16307a80a0a5b5e77..c9b40d3c6195f80a19272a0d98890049d02315b9 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" -#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -77,9 +77,9 @@ XlaInterpreterPlatform::GetUncachedExecutor( if (!init_status.ok()) { return port::Status{ port::error::INTERNAL, - port::Printf( + absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; + config.ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 5741864282ec4722fc961496969ac5f47aa6200f..232d1dc0879cd6931158e642e01fe68e43e6c655 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -28,7 +28,9 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -50,8 +52,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -71,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, } string BufferLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s", - buffer_->ToString().c_str(), - LayoutUtil::HumanString(layout_).c_str()); + return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(), + LayoutUtil::HumanString(layout_)); } OperandLayoutConstraint::OperandLayoutConstraint( @@ -92,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint( } string OperandLayoutConstraint::ToString() const { - return tensorflow::strings::Printf( - "OperandLayoutConstraint %s, operand %lld: %s", - instruction_->name().c_str(), operand_no_, - shape_layout_.ToString().c_str()); + return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s", + instruction_->name(), operand_no_, + shape_layout_.ToString()); } string ResultLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("ResultLayoutConstraint: %s", - shape_layout_.ToString().c_str()); + return absl::StrFormat("ResultLayoutConstraint: %s", + shape_layout_.ToString()); } LayoutConstraints::LayoutConstraints( @@ -168,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Layout of 
buffer %s cannot be constrained because buffer is not " "array-shaped, has shape: %s", - buffer.ToString().c_str(), - ShapeUtil::HumanString(buffer.shape()).c_str()); + buffer.ToString(), ShapeUtil::HumanString(buffer.shape())); } TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); @@ -185,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", - buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint.layout()).c_str(), - LayoutUtil::HumanString(layout).c_str()); + buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()), + LayoutUtil::HumanString(layout)); } iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } else { @@ -221,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } if (curr_shape_layout->mandatory()) { return FailedPrecondition( - "Operand %lld of instruction %s already has a layout constraint " + "Operand %d of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", - operand_no, instruction->name().c_str(), - curr_shape_layout->shape_layout().ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + operand_no, instruction->name(), + curr_shape_layout->shape_layout().ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } } @@ -234,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, // layouts beyond this immediate use and is complicated to handle. 
if (OperandBufferForwarded(instruction, operand_no)) { return FailedPrecondition( - "Cannot constraint layout of operand %lld of instruction %s " + "Cannot constraint layout of operand %d of instruction %s " "because instruction forwards operand's LogicalBuffer(s)", - operand_no, instruction->name().c_str()); + operand_no, instruction->name()); } auto key = std::make_pair(instruction, operand_no); @@ -278,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", - computation_->name().c_str(), curr_shape_layout->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + computation_->name(), curr_shape_layout->ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // New constraint matches existing constraint. Nothing to do. return Status::OK(); @@ -301,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout( if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) { return FailedPrecondition( "Instruction %s of shape %s cannot be assigned incompatible layout %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // Create a BufferLayoutConstraint for each array shape in the output of the @@ -424,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints( return Status::OK(); } +namespace { + +bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + return custom_call != nullptr && custom_call->layout_constrained(); +} + +} // namespace + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, 
ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -439,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { - Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted // instruction. @@ -461,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. - shape_with_layout = ¶meter_layout.shape(); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + parameter_layout.shape(), instruction)); } } - } - if (shape_with_layout != nullptr) { + } else if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(*shape_with_layout, instruction)); - } - - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kRecv) { + constraints->SetInstructionLayout(custom_call->shape(), custom_call)); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + custom_call->operand_shapes_with_layout()[i], custom_call, i)); + } + } else if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); @@ -503,6 +511,22 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } + } else if (instruction->IsCrossModuleAllReduce()) { + CHECK(get_channel_constraints(instruction)) + << "Multi-module layout assignment 
requires ChannelLayoutConstraints"; + int64 all_reduce_id = instruction->all_reduce_id().value(); + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(all_reduce_id)) { + continue; + } + // TODO(b/68493863): Change to use SetOperandLayout(). + const Shape& buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(buffer_shape, all_reduce_id); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(new_buffer_shape, instruction)); } } @@ -610,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( false_computation_layout.parameter_shape(0), instruction, 2, /*mandatory=*/true)); - } else if (instruction->opcode() == HloOpcode::kCustomCall) { - if (!CustomCallRequiresMajorFirstLayout(instruction)) { - continue; - } - // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support major-first layouts for all inputs and outputs. - Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( - instruction->shape().element_type(), - AsInt64Slice(instruction->shape().dimensions())); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction)); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const Shape& operand_shape = instruction->operand(i)->shape(); - // Opaque operands don't get a layout constraint. - if (ShapeUtil::IsOpaque(operand_shape)) { - continue; - } - - Shape row_major_operand_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i)); - } } } // Finally set the result layout to match ComputationLayout, if there is one. 
@@ -665,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call, return Status::OK(); } -// Custom calls have fixed input and output layouts. -Status CheckCustomCallLayout(HloInstruction* custom_call) { - for (const HloInstruction* operand : custom_call->operands()) { - TF_RET_CHECK( - ShapeUtil::IsOpaque(operand->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); +// Operands of layout-constrained custom calls must match the expected +// constrained layouts. +Status CheckCustomCallLayout(HloInstruction* instruction) { + if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); + } } - TF_RET_CHECK( - ShapeUtil::IsOpaque(custom_call->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -753,7 +754,7 @@ Status CheckParameterLayout(HloInstruction* parameter, return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", - parameter->ToString().c_str(), parameter_layout.ToString().c_str()); + parameter->ToString(), parameter_layout.ToString()); } return Status::OK(); } @@ -764,8 +765,8 @@ Status CheckConstantLayout(HloInstruction* constant) { constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", - constant->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str()); + constant->ToString(), + ShapeUtil::HumanStringWithLayout(constant->literal().shape())); } return Status::OK(); } @@ -781,21 +782,27 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( << " instruction: " << instruction->ToString(); if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. 
+ // Copy tuple elements which have differing layouts. std::vector element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); ++i) { + const Shape& target_shape = + ShapeUtil::GetSubshape(shape_with_layout, {i}); + const Shape& instr_shape = + ShapeUtil::GetSubshape(instruction->shape(), {i}); HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - SetupCopiedInstruction(*instruction, gte, {i}); - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); + HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); + + if (ShapeUtil::Equal(target_shape, instr_shape)) { + // Shapes and layouts are equal, no need to copy. + element_copies.push_back(gte); + } else { + SetupCopiedInstruction(*instruction, gte, {i}); + // Recurse to copy each element. + TF_ASSIGN_OR_RETURN(HloInstruction * element_copy, + CreateCopyWithNewLayout(target_shape, gte)); + element_copies.push_back(element_copy); + } } // Gather element copies into a tuple with a new Tuple instruction. HloInstruction* tuple_copy = instruction->parent()->AddInstruction( @@ -860,8 +867,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, ? instruction.sharding().GetSubSharding(instruction.shape(), index) : instruction.sharding(); // We propagate the sharding to the copied instruction only if it is a - // special sharding, like tiled ones, or special devices like the - // HostCompute module. + // special sharding, like tiled ones. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. 
auto device = sharding.UniqueDevice(); @@ -898,13 +904,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", - instruction->name().c_str(), - absl::StrJoin(index, ",").c_str(), - buffer->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction_subshape) - .c_str(), - ShapeUtil::HumanStringWithLayout(buffer->shape()) - .c_str()); + instruction->name(), absl::StrJoin(index, ","), + buffer->ToString(), + ShapeUtil::HumanStringWithLayout(instruction_subshape), + ShapeUtil::HumanStringWithLayout(buffer->shape())); } } } @@ -919,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - if (CustomCallRequiresMajorFirstLayout(instruction)) { - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); - } + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -958,19 +959,23 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, module->entry_computation()) .result_layout(); if (result_layout.LayoutIsSet()) { - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), - result_layout.shape())); + TF_RET_CHECK( + ShapeUtil::Equal(module->result_shape(), result_layout.shape())); } return Status::OK(); } LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), - channel_layout_constraints_(channel_constraints) { + channel_layout_constraints_(channel_constraints), + instruction_can_change_layout_func_( + 
std::move(instruction_can_change_layout_func)) { if (channel_layout_constraints_ != nullptr) { // Save a copy of the input ChannelLayoutConstraints so that we can reset it // if we have to undo previous operations (ClearPreviousPassSideEffects()). @@ -988,16 +993,17 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && + if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape())) { - // Assign operands the same layout as the instruction, so that + ShapeUtil::Rank(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + // Propagate the result layout to the operand layout if the instruction + // requires the same layout out for the result and the operand. + // + // For elementwise operations, using the same layout for the operands and + // the result also has the following benefits: // 1) the elementwise operation can reuse its operand's buffer, and // 2) the input and output elements can reuse the same linear index. - // - // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit - // from assigning the same layout to input and output. return absl::make_unique(output_layout); } @@ -1066,9 +1072,9 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( CHECK(ShapeUtil::IsArray(user->shape()) && ShapeUtil::IsArray(operand->shape())); - if (user->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { + if (!ShapeUtil::IsScalar(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. 
return absl::make_unique(operand_layout); } @@ -1375,7 +1381,7 @@ StatusOr InferArrayLayout( // This should not happen because we've assigned layouts to all // instructions preceding this one. return InternalError("LogicalBuffer %s does not have a layout", - source_buffer->ToString().c_str()); + source_buffer->ToString()); } if (first_buffer_layout == nullptr) { @@ -1390,9 +1396,8 @@ StatusOr InferArrayLayout( return FailedPrecondition( "Array at index {%s} in instruction %s aliases buffers %s " "and %s which have different layouts", - absl::StrJoin(index, ",").c_str(), instruction->name().c_str(), - source_buffers[0]->ToString().c_str(), - source_buffer->ToString().c_str()); + absl::StrJoin(index, ","), instruction->name(), + source_buffers[0]->ToString(), source_buffer->ToString()); } } @@ -1518,22 +1523,13 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Execute extra verification step once the layout has been finalized. TF_RETURN_IF_ERROR(Verify(instruction)); + // Shape must be valid. + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + // Verify all layouts in the shape have been set. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - - // Copy the root instruction's result if its layout does not match the result - // layout constraint. - if (constraints.ResultLayout() != nullptr && - !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_root, - CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), - computation->root_instruction())); - computation->set_root_instruction(new_root); - } - return Status::OK(); } @@ -1549,20 +1545,22 @@ Status LayoutAssignment::CalculateComputationLayout( Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // Clear existing layouts of the instructions. 
All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidentally use the existing layout. + // by the LayoutAssignment pass, except for those on parameters, the + // computation result, and a couple special cases. The former two are + // specified in computation_layout. Clearing the layouts here avoids hiding + // potential bugs in the layout assignment pass that may accidentally use the + // existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction // present in the IR before layout assignment is a bug. return InternalError( "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + instruction->ToString()); } - if (instruction->opcode() != HloOpcode::kInfeed) { + // Some instructions carry mandatory layouts in their shape. + if (instruction->opcode() != HloOpcode::kInfeed && + !IsLayoutConstrainedCustomCall(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } @@ -1663,6 +1661,18 @@ Status LayoutAssignment::RunOnComputation( TF_RETURN_IF_ERROR( ConstrainChannelLayouts(computation, channel_constraints)); } + + // Copy the root instruction's result if its layout does not match the result + // layout constraint. 
+ if (constraints.ResultLayout() != nullptr && + !constraints.ResultLayout()->MatchesLayoutInShape( + computation->root_instruction()->shape())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_root, + CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + computation->root_instruction())); + computation->set_root_instruction(new_root); + } return Status::OK(); } @@ -1718,6 +1728,30 @@ Status LayoutAssignment::ConstrainChannelLayouts( ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); *send_shape = shape; } + } else if (instruction->IsCrossModuleAllReduce()) { + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel(instruction->all_reduce_id().value(), + instruction->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the channel wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. + HloInstruction* operand = instruction->mutable_operand(0); + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + *instruction->mutable_shape() = shape; + } } } return Status::OK(); @@ -1761,6 +1795,18 @@ StatusOr LayoutAssignment::Run(HloModule* module) { } TF_RETURN_IF_ERROR(Init()); + // Verify computation layout is sane. 
+ const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry_computation_layout_->parameter_count() == + entry->num_parameters()); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + TF_RET_CHECK( + ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i), + entry->parameter_instruction(i)->shape())); + } + TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(), + entry->root_instruction()->shape())); + // We do two passes. The first one we pass a nullptr ComputationLayout to // the RunOnComputation() calls (for non entry computations), and we register // the ComputationLayout which are naturally flowing in DFS fashion to the @@ -1812,6 +1858,108 @@ StatusOr LayoutAssignment::Run(HloModule* module) { return true; } +/* static */ +bool LayoutAssignment::InstructionCanChangeLayout( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFft: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kLt: + case HloOpcode::kMap: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case 
HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kReduceWindow: + case HloOpcode::kRemainder: + case HloOpcode::kReverse: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kScatter: + case HloOpcode::kSelect: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return false; + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopy: + case HloOpcode::kCustomCall: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kAfterAll: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return true; + } +} + Status LayoutAssignment::Init() { computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 3e000ec2df4c6deb9e482d9e2cb76773905f2770..cb56f4cd19ded036ef521a579eb7d6ea7f3b6268 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ 
b/tensorflow/compiler/xla/service/layout_assignment.h @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,8 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -228,8 +228,8 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; - mutable tensorflow::gtl::FlatMap> + mutable absl::flat_hash_map> buffer_sets_cache_; HloComputation* computation_; @@ -281,11 +281,16 @@ class ChannelLayoutConstraints { // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. -class LayoutAssignment : public HloPassInterface { +class LayoutAssignment : public HloModulePass { public: // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. // + // instruction_can_change_layout_func is a function object that determines + // whether an instruction can change layouts. An instruction not being able to + // change layout means that it requires operands with the same rank as the + // output to have the same layout as the output. + // // channel_constraints is both an input and output. Any sends or recvs that // are present in channel_constraints will be laid out as constrained. 
Any // unconstrained sends or recvs will be laid out as locally optimal and their @@ -295,6 +300,8 @@ class LayoutAssignment : public HloPassInterface { // within any module passed to `Run`. explicit LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func = InstructionCanChangeLayout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} absl::string_view name() const override { return "layout-assignment"; } @@ -303,6 +310,11 @@ class LayoutAssignment : public HloPassInterface { // (any layouts were changed). StatusOr Run(HloModule* module) override; + // Determines whether an instruction can change layouts. An instruction not + // being able to change layout means that it requires operands with the same + // rank as the output to have the same layout as the output. + static bool InstructionCanChangeLayout(const HloInstruction* instruction); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. operands and users) in order to minimize @@ -321,19 +333,6 @@ class LayoutAssignment : public HloPassInterface { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); - // By default LayoutAssignment ensures that inputs and outputs of CustomCalls - // have the "major-first" layout (i.e. {n, n-1, ..., 0}). - // - // If this function returns true, LayoutAssignment does not set a layout for - // the given CustomCall. It's up to the backend to set one in - // AddBackendConstraints, if necessary. - // - // Precondition: instruction->opcode() == HloOpcode::kCustomCall. - virtual bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* /*instruction*/) { - return true; - } - // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. 
virtual Status Verify(const HloInstruction* instruction) { @@ -499,7 +498,7 @@ class LayoutAssignment : public HloPassInterface { // Every copy added to the module by the layout assignment pass is registered // here. - tensorflow::gtl::FlatSet added_copies_; + absl::flat_hash_set added_copies_; // The pointer to the channel layout constraints passed in with the // constructor. If not nullptr, this is an input/output argument. @@ -516,8 +515,10 @@ class LayoutAssignment : public HloPassInterface { // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. - tensorflow::gtl::FlatSet - unconstrained_layout_instructions_; + absl::flat_hash_set unconstrained_layout_instructions_; + + std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6d05fa5fe290fb616b824d4fcd49ca2385d1dbb8..a831751fa96f8cef233e16fe02378ac036efc8ab 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -34,13 +35,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -49,13 +49,14 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloTestBase { +class LayoutAssignmentTest : public HloVerifiedTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( - entry_computation_layout, /*channel_constraints=*/channel_constraints); + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } @@ -64,6 +65,27 @@ class LayoutAssignmentTest : public HloTestBase { FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } + + void ExpectLayoutIs(const Shape& shape, + absl::Span minor_to_major) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected)) + << "Expected layout " << expected << ", actual " << shape.layout(); + } + + void ExpectTupleLayoutIs( + const Shape& shape, + std::initializer_list> minor_to_majors) { + int i = 0; + for (const absl::Span minor_to_major : minor_to_majors) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + const Layout& actual = ShapeUtil::GetTupleElementShape(shape, 
i).layout(); + EXPECT_TRUE(LayoutUtil::Equal(actual, expected)) + << "Expected tuple element " << i << " layout " << expected + << ", actual " << actual; + ++i; + } + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -91,7 +113,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -127,7 +149,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -145,7 +167,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); - Shape ashape = constant_literal1->shape(); + Shape ashape = constant_literal1.shape(); auto constant1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(constant_literal1))); @@ -172,7 +194,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, 
&computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -213,7 +235,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -243,7 +265,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -255,7 +277,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -294,7 +316,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -310,7 +332,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module.get()) + .Run(module) .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. 
@@ -352,7 +374,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -393,7 +415,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -432,7 +454,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -457,13 +479,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, f32_4, "param")); auto broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_34, param, {3})); + HloInstruction::CreateBroadcast(f32_34, param, {1})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); auto broadcast2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); auto 
module = CreateNewModule(); @@ -485,7 +507,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -551,7 +573,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + EXPECT_IS_OK(layout_assignment.Run(module).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -575,7 +597,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -593,7 +615,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -659,18 +681,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - 
auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); - module = + std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); EXPECT_EQ(Status::OK(), backend() .compiler() - ->RunBackend(std::move(module), + ->RunBackend(std::move(compiled_module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .status()); @@ -699,9 +721,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -713,19 +735,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(&module(), &computation_layout); - EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(0) 
.layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -785,7 +807,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -812,7 +834,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module.get()).status(); + Status error_status = layout_assignment.Run(module).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -839,7 +861,51 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); + ComputationLayout computation_layout( + module().entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(&module(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE(ShapeUtil::Equal( + 
ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); +} + +TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { + // Pin non matching layouts to parameter and root. + const char* module_str = R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + ar.0 = f32[2,2] cross-replica-sum(gte), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=0} + const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] cross-replica-sum(const), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=1} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( @@ -854,11 +920,375 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { AssignLayouts(module.get(), &computation_layout, &channel_constraints); EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); - EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::GetSubshape( - FindInstruction(module.get(), "send")->shape(), {0}), - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); + EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); +} + +TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopySliceOperandToAvoidImplicitLayoutChange + + ENTRY 
CopySliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]} + ROOT add0 = f32[3,4]{1,0} add(par0,slice0) + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, op::Add(op::Parameter(), + op::Slice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy))))); +} + +TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyDSliceOperandToAvoidImplicitLayoutChange + + ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + par2 = s32[2] parameter(2) + dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); +} + +TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyConcatOperandToAvoidImplicitLayoutChange + + ENTRY CopyConcatOperandToAvoidImplicitLayoutChange { + par0 = f32[3,8]{1,0} parameter(0) + par1 = 
f32[3,5]{0,1} parameter(1) + par2 = f32[3,3]{1,0} parameter(2) + concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2), + dimensions={1} + ROOT add0 = f32[3,8]{1,0} add(par0,concat0) + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::Concatenate(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); +} + +TEST_F(LayoutAssignmentTest, + ConvolutionOperandWithImplicitLayoutChangeNotCopied) { + const char* module_str = R"( + HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied + + ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied { + par0 = f32[128,3,230,230]{2,3,1,0} parameter(0) + par1 = f32[7,7,3,64]{3,2,0,1} parameter(1) + ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1), + window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01, + feature_group_count=1 + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); +} + +TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { + const char* module_str = R"( + HloModule PropagatingLayoutFromResultToOperand + + ENTRY PropagatingLayoutFromResultToOperand { + par0 = f32[4,5]{1,0} parameter(0) + ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]} + } + )"; + + ParseAndVerifyModule(module_str); + auto compiled_module 
= + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); + EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), + op::ShapeWithLayout(shape_copy)))); +} + +TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { + // The first infeed uses layout {0,1}, while the second uses layout {1,0}. + // The mismatch forces a copy of the tuple. The tuple contains a token, so + // layout assignment will fail if it tries to copy the whole tuple. + const char* module_str = R"( + HloModule TupleCopyOnLayoutMismatch + + condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] { + tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.1 = s32[] get-tuple-element(tup.1), index=0 + five = s32[] constant(5) + ROOT lt = pred[] less-than(counter.1, five) + } + + body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { + tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.2 = s32[] get-tuple-element(tup.2), index=0 + tok.2 = token[] get-tuple-element(tup.2), index=1 + + ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2) + next_tok = token[] get-tuple-element(ifeed.2), index=1 + next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0 + + one = s32[] constant(1) + next_counter = s32[] add(counter.2, one) + ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf) + } + + ENTRY main () -> f32[512,1024]{0,1} { + start_tok = token[] after-all() + + ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok) + itok = token[] get-tuple-element(ifeed.3), index=1 + ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0 + + zero = s32[] constant(0) + itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, 
ibuf) + + loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2 + ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2 + } + )"; + + ParseAndVerifyModule(module_str); + ComputationLayout computation_layout( + module().entry_computation()->ComputeProgramShape()); + + // Sanity check to verify that there's a layout mismatch. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + + AssignLayouts(&module(), &computation_layout); + + // Make sure that layout assignment did not magically eliminate the mismatch, + // in which case the test didn't prove anything. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); +} + +TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallNotLayoutConstrained + +ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { + %p = f32[42,2,3] parameter(0) + ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz" +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. 
+ { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); + } + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); + } +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + 
std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedZeroOperands + +ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall())); + + const HloInstruction* custom_call = + 
module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleOperand + +ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Tuple()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleResult + +ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) { + %p0 = f32[4,4] 
parameter(0) + ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}} +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); + AssignLayouts(module.get(), &computation_layout); + + ExpectTupleLayoutIs(module->result_shape(), {{1, 0}, {1, 0}}); + + const HloInstruction* custom_call = + FindInstruction(module.get(), "custom-call"); + ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index b17c9d504501a907e27d5152e0082799e87443c7..d287aa4ec7bbcd11f51ea07cd2a1572e59f0d6c6 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -21,8 +21,24 @@ limitations under the License. 
#endif namespace xla { +Status LLVMCompiler::RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); +} + +StatusOr>> +LLVMCompiler::RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); +} + StatusOr>> LLVMCompiler::Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { // Tensorflow tries to enable the following behaviors in all its threads: @@ -38,6 +54,8 @@ StatusOr>> LLVMCompiler::Compile( tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals; std::vector> result; + std::vector> modules = + module_group->ConsumeModules(); for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { return Unimplemented( diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index f1c623508c5307f2b1c036d3ec6823b75c7eda13..86abd5da0189feb0eadfde3d6dbab446eb2be900 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -69,8 +69,17 @@ class LLVMCompiler : public Compiler { using Compiler::RunBackend; using Compiler::RunHloPasses; + Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, se::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) override; + StatusOr>> Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) override; diff --git 
a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index fc3289f30d5399cf7ef3320ebef6d6ff235dbe44..5f7ad81d82978d0a752b33d12b72e16f0c1c6826 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], @@ -71,6 +73,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -92,6 +95,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -108,6 +112,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -125,6 +130,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -162,6 +168,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -197,8 +204,8 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@llvm//:core", + "@llvm//:support", ], ) @@ -213,6 +220,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/core:lib", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -248,3 +256,12 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "ir_builder_mixin", + srcs = [], + hdrs = ["ir_builder_mixin.h"], + deps = [ + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index e5370eca56f2e3a891523ba2b72961d66ec809aa..643ecd0fbaa546c551097b29e74ccd49418e1466 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" -#include +#include #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( add_buffers_to_worklist(operand); } - tensorflow::gtl::FlatSet - buffers; + std::set buffers; for (const LogicalBuffer* buffer : worklist) { // Skip buffers which cannot be added to the noalias set. if (!assignment.HasAllocation(*buffer) || diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 8d9fa99d82b4e49b653d9f05cc9baa5e3fdcefa6..2b46b3c3964b15548dbacc8b0ada0047a0fa85b6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,14 +16,13 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace llvm_ir { @@ -77,14 +76,14 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - tensorflow::gtl::FlatMap + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. - tensorflow::gtl::FlatMap + absl::flat_hash_map noalias_metadata_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index fe5ec1cc66d06e85ce70625ef7cf764a37b29166..b6ae4932f5707f1d15af1e09a735a7de2e48fac5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -61,7 +61,7 @@ ENTRY while3 { ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %buffer_table, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; 
CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index ad350613dd23f4a477c422a6311f1b03bc681574..cc2e862f2eb9a49099c5f90efe1b29fb77c8f106 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -99,9 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name); } -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) { +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the @@ -129,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace( // // Emits a sequential loop if launch_dimensions is null. 
static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -173,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( } Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( @@ -183,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace( } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index e1631a62ae8486f03a4fe8fcb32f1b49d5dd2339..fb3e4eb97cae06f2a0c87dd7118b8332048df56e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -63,25 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // Emits IR for running the given dynamic-update-slice op in-place -- that is, // where the input and output buffers share the same slice, so we can simply // modify the input/output buffer without touching any of the other elements. 
-Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b); +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits // (sequential) code for a fusion node that does the dynamic-update-slice in // place. Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 72ede377e1a505d5e4916915e18827e1a0f3fdf9..b606c993a2d58a6d177af10de7b214de130c2279 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement( return Unimplemented( "GetTupleElement fusion currently only supports" " parameter operands, but found operand: %s", - operand->name().c_str()); + operand->name()); } // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. 
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( @@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { } Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); std::vector operand_elemental_ir_types; for (HloInstruction* operand : operands) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 30471480c4fb3ce3bf3226a28e9d2ffa79ae5f29..44d21fa750a532633f46614002d59c90fc0b5d40 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; - FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, + FusedIrEmitter(absl::Span parameter_arrays, ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), tiled_parameter_info_(nullptr), @@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { private: // Arrays of parameters of fusion instruction - tensorflow::gtl::ArraySlice parameter_arrays_; + absl::Span parameter_arrays_; const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 6971220022d9d3fe5caded731977df4dfffd2992..67f7423121177e2ca1e3384341dad2644c8f5e34 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, Delinearize(&multidim_, linear, shape, b); } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), linear_(linear), @@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, << " should have a layout."; } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, const Shape& shape, llvm::IRBuilder<>* b) : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), @@ -147,16 +147,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // indices in the same common factor. 
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = - Index(tensorflow::gtl::ArraySlice( - multidim_, common_factors[k].second, + Index(absl::Span(multidim_).subspan( + common_factors[k].second, common_factors[k + 1].second - common_factors[k].second), index_type_) - .Linearize( - tensorflow::gtl::ArraySlice( - AsInt64Slice(output_shape.dimensions()), - common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second), - builder); + .Linearize(AsInt64Slice(output_shape.dimensions()) + .subspan(common_factors[k].second, + common_factors[k + 1].second - + common_factors[k].second), + builder); // Delinearizes logical_linear_index for the source array in row-major // collapsed order. The first rank-1 indices are the remainder of the // linear index by each dimension size. @@ -185,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, - llvm::IRBuilder<>* builder) const { + const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const { Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; @@ -208,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice( IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); @@ -257,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { int64 rank = 
ShapeUtil::Rank(operand_shape); std::vector source_index(rank); @@ -322,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( return Index(source_index, linear, operand_shape); } -llvm::Value* IrArray::Index::Linearize( - tensorflow::gtl::ArraySlice dimensions, - llvm::IRBuilder<>* builder) const { +llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, + llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. CHECK_EQ(size(), dimensions.size()); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index e913c109b3ff0e4e7192e501a314aa381a4268b0..f4b05f29c38529b3cce81b4c8ee6fae5c00cafcc 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -21,12 +21,12 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -70,7 +70,7 @@ class IrArray { // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim, + explicit Index(absl::Span multidim, llvm::Type* index_ty = nullptr) : multidim_(multidim.begin(), multidim.end()) { if (size() == 0) { @@ -99,14 +99,14 @@ class IrArray { // that it indexes into. // // Precondition: "shape" has a layout. 
- Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* b); + Index(absl::Span multidim, const Shape& shape, + llvm::IRBuilder<>* b); // Constructs an index from both a multi-dimensional index and a linear // index. "shape" has the same meaning as that in the constructor that takes // only a linear index. - Index(tensorflow::gtl::ArraySlice multidim, - llvm::Value* linear, const Shape& shape); + Index(absl::Span multidim, llvm::Value* linear, + const Shape& shape); const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } @@ -145,17 +145,15 @@ class IrArray { // by starting indices `starts` and stride values `strides`. // // Precondition: "this" is an index into a slice whose shape is `shape`. - Index SourceIndexOfSlice(const Shape& shape, - tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, + Index SourceIndexOfSlice(const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfTranspose( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a bitcast from `operand_shape` // to `shape`, returns the source index. @@ -164,14 +162,13 @@ class IrArray { // Given that "this" is the target index of a broadcast from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. 
- Index SourceIndexOfBroadcast( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. - llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, + llvm::Value* Linearize(absl::Span dimensions, llvm::IRBuilder<>* builder) const; llvm::Type* GetType() const { return index_type_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h new file mode 100644 index 0000000000000000000000000000000000000000..abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -0,0 +1,400 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ + +#include "llvm/IR/IRBuilder.h" + +namespace xla { + +// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods +// into a class. 
Intended to be used as a CRTP base class, like: +// +// class MyIrEmitter : public IrBuilderMixin { +// llvm::IRBuilder<>* builder() { return builder_; } +// +// void EmitFoo(HloInstruction* foo) { +// Add(Mul(...), FPToUI(...)); +// } +// }; + +template +class IrBuilderMixin { + protected: + template + llvm::Value* Add(Args&&... args) { + return mixin_builder()->CreateAdd(std::forward(args)...); + } + + template + llvm::LoadInst* AlignedLoad(Args&&... args) { + return mixin_builder()->CreateAlignedLoad(std::forward(args)...); + } + + template + llvm::StoreInst* AlignedStore(Args&&... args) { + return mixin_builder()->CreateAlignedStore(std::forward(args)...); + } + + template + llvm::AllocaInst* Alloca(Args&&... args) { + return mixin_builder()->CreateAlloca(std::forward(args)...); + } + + template + llvm::Value* And(Args&&... args) { + return mixin_builder()->CreateAnd(std::forward(args)...); + } + + template + llvm::Value* AtomicCmpXchg(Args&&... args) { + return mixin_builder()->CreateAtomicCmpXchg(std::forward(args)...); + } + + template + llvm::Value* AtomicRMW(Args&&... args) { + return mixin_builder()->CreateAtomicRMW(std::forward(args)...); + } + + template + llvm::Value* BitCast(Args&&... args) { + return mixin_builder()->CreateBitCast(std::forward(args)...); + } + + template + llvm::Value* Br(Args&&... args) { + return mixin_builder()->CreateBr(std::forward(args)...); + } + + llvm::CallInst* Call(llvm::Value* callee, + llvm::ArrayRef args = llvm::None, + const llvm::Twine& name = "", + llvm::MDNode* fp_math_tag = nullptr) { + return mixin_builder()->CreateCall(callee, args, name, fp_math_tag); + } + + template + llvm::BranchInst* CondBr(Args&&... args) { + return mixin_builder()->CreateCondBr(std::forward(args)...); + } + + template + llvm::Value* ConstInBoundsGEP1_32(Args&&... args) { + return mixin_builder()->CreateConstInBoundsGEP1_32( + std::forward(args)...); + } + + template + llvm::Value* FAdd(Args&&... 
args) { + return mixin_builder()->CreateFAdd(std::forward(args)...); + } + + template + llvm::Value* FMul(Args&&... args) { + return mixin_builder()->CreateFMul(std::forward(args)...); + } + + llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateGEP(ptr, idx_list, name); + } + + template + llvm::Value* ICmpEQ(Args&&... args) { + return mixin_builder()->CreateICmpEQ(std::forward(args)...); + } + + template + llvm::Value* ICmpNE(Args&&... args) { + return mixin_builder()->CreateICmpNE(std::forward(args)...); + } + + template + llvm::Value* ICmpULE(Args&&... args) { + return mixin_builder()->CreateICmpULE(std::forward(args)...); + } + + template + llvm::Value* ICmpULT(Args&&... args) { + return mixin_builder()->CreateICmpULT(std::forward(args)...); + } + + llvm::Value* InBoundsGEP(llvm::Value* ptr, + llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name); + } + + llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateExtractValue(agg, idxs, name); + } + + llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val, + llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInsertValue(agg, val, idxs, name); + } + + template + llvm::Value* IntToPtr(Args&&... args) { + return mixin_builder()->CreateIntToPtr(std::forward(args)...); + } + + template + llvm::LoadInst* Load(Args&&... args) { + return mixin_builder()->CreateLoad(std::forward(args)...); + } + + template + llvm::CallInst* MemCpy(Args&&... args) { + return mixin_builder()->CreateMemCpy(std::forward(args)...); + } + + template + llvm::Value* Mul(Args&&... args) { + return mixin_builder()->CreateMul(std::forward(args)...); + } + + template + llvm::Value* NSWAdd(Args&&... 
args) { + return mixin_builder()->CreateNSWAdd(std::forward(args)...); + } + + template + llvm::Value* NSWMul(Args&&... args) { + return mixin_builder()->CreateNSWMul(std::forward(args)...); + } + + template + llvm::Value* NSWSub(Args&&... args) { + return mixin_builder()->CreateNSWSub(std::forward(args)...); + } + + template + llvm::Value* Or(Args&&... args) { + return mixin_builder()->CreateOr(std::forward(args)...); + } + + template + llvm::Value* PointerCast(Args&&... args) { + return mixin_builder()->CreatePointerCast(std::forward(args)...); + } + + template + llvm::Value* PtrToInt(Args&&... args) { + return mixin_builder()->CreatePtrToInt(std::forward(args)...); + } + + template + llvm::Value* SDiv(Args&&... args) { + return mixin_builder()->CreateSDiv(std::forward(args)...); + } + + template + llvm::Value* Select(Args&&... args) { + return mixin_builder()->CreateSelect(std::forward(args)...); + } + + template + llvm::Value* SRem(Args&&... args) { + return mixin_builder()->CreateSRem(std::forward(args)...); + } + + template + llvm::StoreInst* Store(Args&&... args) { + return mixin_builder()->CreateStore(std::forward(args)...); + } + + template + llvm::Value* UDiv(Args&&... args) { + return mixin_builder()->CreateUDiv(std::forward(args)...); + } + + template + llvm::Value* URem(Args&&... args) { + return mixin_builder()->CreateURem(std::forward(args)...); + } + + template + llvm::Value* VectorSplat(Args&&... args) { + return mixin_builder()->CreateVectorSplat(std::forward(args)...); + } + + template + llvm::Value* ZExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateZExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* AShr(Args&&... args) { + return mixin_builder()->CreateAShr(std::forward(args)...); + } + + template + llvm::Value* FCmpOEQ(Args&&... args) { + return mixin_builder()->CreateFCmpOEQ(std::forward(args)...); + } + + template + llvm::Value* FCmpOLT(Args&&... 
args) { + return mixin_builder()->CreateFCmpOLT(std::forward(args)...); + } + + template + llvm::Value* FCmpONE(Args&&... args) { + return mixin_builder()->CreateFCmpONE(std::forward(args)...); + } + + template + llvm::Value* FCmpUNE(Args&&... args) { + return mixin_builder()->CreateFCmpUNE(std::forward(args)...); + } + + template + llvm::Value* FDiv(Args&&... args) { + return mixin_builder()->CreateFDiv(std::forward(args)...); + } + + template + llvm::Value* FNeg(Args&&... args) { + return mixin_builder()->CreateFNeg(std::forward(args)...); + } + + template + llvm::Value* FPCast(Args&&... args) { + return mixin_builder()->CreateFPCast(std::forward(args)...); + } + + template + llvm::Value* FPToSI(Args&&... args) { + return mixin_builder()->CreateFPToSI(std::forward(args)...); + } + + template + llvm::Value* FPToUI(Args&&... args) { + return mixin_builder()->CreateFPToUI(std::forward(args)...); + } + + template + llvm::Value* FPTrunc(Args&&... args) { + return mixin_builder()->CreateFPTrunc(std::forward(args)...); + } + + template + llvm::Value* FRem(Args&&... args) { + return mixin_builder()->CreateFRem(std::forward(args)...); + } + + template + llvm::Value* FSub(Args&&... args) { + return mixin_builder()->CreateFSub(std::forward(args)...); + } + + template + llvm::Value* ICmpSGE(Args&&... args) { + return mixin_builder()->CreateICmpSGE(std::forward(args)...); + } + + template + llvm::Value* ICmpSLT(Args&&... args) { + return mixin_builder()->CreateICmpSLT(std::forward(args)...); + } + + template + llvm::Value* IntCast(Args&&... args) { + return mixin_builder()->CreateIntCast(std::forward(args)...); + } + + template + llvm::Value* LShr(Args&&... args) { + return mixin_builder()->CreateLShr(std::forward(args)...); + } + + template + llvm::Value* MemSet(Args&&... args) { + return mixin_builder()->CreateMemSet(std::forward(args)...); + } + + template + llvm::Value* Neg(Args&&... 
args) { + return mixin_builder()->CreateNeg(std::forward(args)...); + } + + template + llvm::Value* Not(Args&&... args) { + return mixin_builder()->CreateNot(std::forward(args)...); + } + + template + llvm::PHINode* PHI(Args&&... args) { + return mixin_builder()->CreatePHI(std::forward(args)...); + } + + template + llvm::Value* RetVoid(Args&&... args) { + return mixin_builder()->CreateRetVoid(std::forward(args)...); + } + + template + llvm::Value* SExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateSExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* Shl(Args&&... args) { + return mixin_builder()->CreateShl(std::forward(args)...); + } + + template + llvm::Value* SIToFP(Args&&... args) { + return mixin_builder()->CreateSIToFP(std::forward(args)...); + } + + template + llvm::Value* Sub(Args&&... args) { + return mixin_builder()->CreateSub(std::forward(args)...); + } + + template + llvm::Value* Trunc(Args&&... args) { + return mixin_builder()->CreateTrunc(std::forward(args)...); + } + + template + llvm::Value* UIToFP(Args&&... args) { + return mixin_builder()->CreateUIToFP(std::forward(args)...); + } + + template + llvm::Value* Unreachable(Args&&... args) { + return mixin_builder()->CreateUnreachable(std::forward(args)...); + } + + template + llvm::Value* Xor(Args&&... 
args) { + return mixin_builder()->CreateXor(std::forward(args)...); + } + + private: + llvm::IRBuilder<>* mixin_builder() { + return static_cast(this)->builder(); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index b152cf9275c86ece2e049d193c45e07db22a1170..43fec311f150d6054f6ad24f99db332f90ff94a3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -235,7 +235,7 @@ class KernelSupportLibrary { })); } - using ArgumentVector = tensorflow::gtl::ArraySlice; + using ArgumentVector = absl::Span; // Generates the following control flow structure: // diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index cb4d1db997c133636dab12393d371b6e5a7452eb..e5fbdbd51b8a9aa14decadedd1eeb3bdbf831738 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -28,7 +28,7 @@ namespace { // Returns the indices of the first elements of all consecutive subarrays of the // given array. For example: // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { +std::vector ConsecutiveSegments(absl::Span xs) { std::vector is = {0}; for (size_t i = 1; i < xs.size(); ++i) { if (1 != xs[i] - xs[i - 1]) { @@ -40,8 +40,7 @@ std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { // Merges the sequences of dimensions of the given shape which start at the // given indices `segs`. 
-Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, - const Shape& shape) { +Shape MergeDimensions(absl::Span segs, const Shape& shape) { std::vector dimensions; for (size_t i = 1; i <= segs.size(); ++i) { dimensions.push_back(std::accumulate( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 8bd06c42c3cd2cb905191572d0a0722e778734f9..5ea05b3188a1c0881e4c0c41625d530aff1b1205 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex( // for 021 transpose. class TiledParameterInfo { public: - TiledParameterInfo(tensorflow::gtl::ArraySlice param_buffers, + TiledParameterInfo(absl::Span param_buffers, llvm::Value* y, llvm::Value* x) : param_buffers_(param_buffers), y_(y), x_(x) {} @@ -67,7 +67,7 @@ class TiledParameterInfo { private: // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr // if the parameter is not tiled. - tensorflow::gtl::ArraySlice param_buffers_; + absl::Span param_buffers_; // The y coordinate within a tile. llvm::Value* y_; // The x coordinate within a tile. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 978fa5b453569687023c9867604f1be7ece4ee7a..219a9f221fbd116cdfbaf17985e21d82aefd079d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -36,8 +35,8 @@ ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) - : prefix_(std::string(prefix)), - suffix_(std::string(suffix)), + : prefix_(prefix), + suffix_(suffix), start_index_(start_index), end_index_(end_index), step_(step), @@ -242,7 +241,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, } IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, + const Shape& shape, absl::Span dimensions, absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 62aa15fe2dc07dff622178477660a3cd9086d3ff..ac3bba3c9fd6a9eb4e7822474963fcc5a394baf7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -21,13 +21,13 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -183,7 +183,7 @@ class ForLoopNest { ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : name_(std::string(name)), + : name_(name), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), @@ -242,7 +242,7 @@ class ForLoopNest { // size equals the rank of shape and there is a null for each // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, + const Shape& shape, absl::Span dimensions, absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. 
Loops diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index f0db2a3761afd3e887979d307fb3b9a557eea491..1a53c026be340ca3bec3a49b11666d6124728130 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b) { +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index dde50e19d1c77491fb843710ea765ecb2e8af932..f59baff263fe7184c6b0821c9dbd9eee205586a6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace llvm { @@ -59,7 +59,7 @@ llvm::ArrayRef AsArrayRef(const std::vector& vec) { } template -llvm::ArrayRef AsArrayRef(const tensorflow::gtl::ArraySlice& slice) { +llvm::ArrayRef AsArrayRef(const absl::Span& slice) { return llvm::ArrayRef(slice.data(), slice.size()); } @@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b); +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index cf7445804c0c35f408139e5f815579f70a35b1ad..0dc120e0b0df47f261435f490a8459b49d989b53 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -18,13 +18,13 @@ limitations under the License. 
#include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( } LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, llvm::IRBuilder<>* b) : body_emitter_(MakeBodyEmitterForMultiOutputFusion( target_element_generator, @@ -105,7 +105,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 57d9d8bbc61014d423822ab5c1e4d251349df89c..a537c00066b0a68404b142e91283510092b46e2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -53,8 +53,7 @@ class LoopEmitter { // This is used for multi-output fusion. target_element_generator must // produce an LLVM struct with N elements. 
LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* b); + absl::Span target_arrays, llvm::IRBuilder<>* b); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 00dd3f16389156afcf3824af0ce57763a82c0ad4..05ba4a40da413f0e774214e55ef69d023afc48e2 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" +#include + // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/string_view.h" -#include "absl/types/optional.h" +#include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -42,7 +44,7 @@ namespace { void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, const IrArray::Index& compare_keys_index, const IrArray& keys_array, - const absl::optional& values_array, + const std::vector& values_arrays, llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) @@ -59,15 +61,39 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, SetToFirstInsertPoint(if_data.true_block, b); auto key1 = keys_array.EmitReadArrayElement(keys_index, b); auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); + auto compare_key1 = key1; + auto compare_key2 = key2; auto key_type = keys_array.GetShape().element_type(); + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the sort + // has a predictable behavior in the presence of NaNs. 
Rather than using + // floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } auto comparison = - primitive_util::IsFloatingPointType(key_type) - // TODO(b/26783907): Figure out how to handle NaNs. - ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) - : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) - ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - key2, key1); + b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1); // If key2 < key1 auto if_smaller_data = EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); @@ -75,19 +101,18 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, // Swap key1 with key2. keys_array.EmitWriteArrayElement(keys_index, key2, b); keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); - if (values_array.has_value()) { + for (const auto& values_array : values_arrays) { // Also swap the values. 
- auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); - auto value2 = - values_array.value().EmitReadArrayElement(compare_keys_index, b); - values_array.value().EmitWriteArrayElement(keys_index, value2, b); - values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + auto value1 = values_array.EmitReadArrayElement(keys_index, b); + auto value2 = values_array.EmitReadArrayElement(compare_keys_index, b); + values_array.EmitWriteArrayElement(keys_index, value2, b); + values_array.EmitWriteArrayElement(compare_keys_index, value1, b); } } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const absl::optional& values_array, + const std::vector& values_arrays, absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { @@ -137,7 +162,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, compare_keys_index[dimension_to_sort] = b->CreateXor(compare_index[0], xor_mask); EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, values_array, b); + keys_array, values_arrays, b); return Status::OK(); }; if (launch_dimensions != nullptr) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 527ed10374ce9482045a8459e38fd041e0e83001..2f3bcda2307bcbb35a03b9e71dbbe44e366b3820 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,8 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include + #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -31,7 +32,7 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const absl::optional& values_array, + const std::vector& values_arrays, absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 11ed6ee59f1bf8e7004b8bef7319b37ef41a304c..a60643bc754f896d096b3ca4e1216e77d7e384c6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, } } -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { auto* store = b->CreateStore( @@ -76,6 +75,16 @@ void EmitTuple(const IrArray& tuple, } } +void EmitTuple(const IrArray& tuple, absl::Span buffers, + llvm::IRBuilder<>* b, llvm::Module* module) { + std::vector buffer_ptrs; + buffer_ptrs.reserve(buffers.size()); + absl::c_transform( + buffers, std::back_inserter(buffer_ptrs), + [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); +} + llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int 
alignment, llvm::Value* operand, llvm::IRBuilder<>* b, llvm::Module* module) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index cf6bf5d0b14ba71cbed67f9a1dc728c0eef5e393..94340b91d8eeea1ba4681c2e49c0894eab2f6cc0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" // Utilities for emitting LLVM IR related to HLO tuples. @@ -65,8 +65,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, + llvm::IRBuilder<>* b, llvm::Module* module); + +// Similar to EmitTuple above, except that the output buffers are provided in +// the form of IrArray. +void EmitTuple(const IrArray& tuple, absl::Span buffers, llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index ea59adadea1277b265938468d7139ed50f8a08a7..cca37556173bb95ef062b59ab0a4bf9ca7c496fe 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -140,16 +141,16 @@ ExecutionOptions CreateExecutionOptions( StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); - TF_RET_CHECK(proto.has_program_shape()); - const ProgramShape& program_shape = proto.program_shape(); + TF_RET_CHECK(proto.has_host_program_shape()); + const ProgramShape& program_shape = proto.host_program_shape(); // Validate incoming layouts. if (argument_layouts.size() != program_shape.parameters_size()) { return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", + "Invalid number of arguments for computation: expected %d, got %u.", program_shape.parameters_size(), argument_layouts.size()); } @@ -167,16 +168,15 @@ StatusOr> LocalService::CompileExecutable( CHECK(metadata.value() != nullptr); const OpMetadata& m = *metadata.value(); if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); + return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line()); } return ""; }; return InvalidArgument( "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); + metadata_string(), + ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(argument_shape)); } } if (build_options.result_layout() != nullptr) { @@ -214,7 +214,7 @@ StatusOr LocalService::GlobalDataToShapedBuffer( 
TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { return InvalidArgument( - "replica_number %d out of range; must be less than num_replicas = %zu.", + "replica_number %d out of range; must be less than num_replicas = %u.", replica_number, buffers.size()); } return buffers[replica_number]; diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 8f707ea9046a00a15cac469672a7a992f20bf483..3b4f0b50832d6d2b64528ffb63eb5c7375396aec 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -48,7 +48,7 @@ class LocalService : public Service { // compiler is responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); // Returns the device ordinal that corresponds to the given replica number. 
diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index f9ba5a554740c9d4cc2643fe59d18ba76c30d03b..ceacab4ed7319527312a5a6ad715103b5bbaf40f 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -18,13 +18,13 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index eaa09591b72ee5202e0a9d1225d92eca92904adc..ec52a24d782a44fda961feab3230886072e755c7 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() { // so reserve 10% more than the number of instructions to avoid frequent // resizes. logical_buffers_.clear(); - logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10); + logical_buffers_.reserve((module_->instruction_count() * 11) / 10); // We filter out fusion computations, and get to them through fusion // instructions. 
This is because it's possible to have orphaned (unreachable) diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc similarity index 75% rename from tensorflow/compiler/xla/service/inliner.cc rename to tensorflow/compiler/xla/service/map_inliner.cc index 5c193fceb984448cf0532d7e1010281268614293..2200ef054a6993fb884751643ab1fb5ab83efe05 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/map_inliner.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,15 +28,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" namespace xla { -// InlinerVisitor traverses the HLO computation and inlines maps. -class InlinerVisitor : public DfsHloVisitorWithDefault { +// MapInlinerVisitor traverses the HLO computation and inlines maps. +class MapInlinerVisitor : public DfsHloVisitorWithDefault { public: - explicit InlinerVisitor(HloComputation* computation) + explicit MapInlinerVisitor(HloComputation* computation) : computation_(computation) {} // Default visitor action is to do nothing and return OK. 
@@ -49,48 +49,44 @@ class InlinerVisitor : public DfsHloVisitorWithDefault { StatusOr Run(HloComputation* computation); private: - // Current HloComputation instance the InlinerVisitor is traversing. + // Current HloComputation instance the MapInlinerVisitor is traversing. HloComputation* computation_; // Whether algebraic simplification has occurred. bool changed_ = false; }; -StatusOr InlinerVisitor::Run(HloComputation* computation) { +StatusOr MapInlinerVisitor::Run(HloComputation* computation) { changed_ = false; computation_ = computation; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); return changed_; } -Status InlinerVisitor::HandleMap(HloInstruction* map) { +Status MapInlinerVisitor::HandleMap(HloInstruction* map) { HloComputation* function = map->to_apply(); HloInstruction& root = *function->root_instruction(); - // TODO(b/29249531): Add DCE pass to remove unused HloComputations. // Only inlining functions that are simply a single operation until a better // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { if (root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { // Cloning not supported for these instructions. return Status::OK(); } VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " << root.ToShortString(); - // If the input is a constant then the shape of the constant could be - // different than the map shape. Hence, a broadcast is needed, else the - // cloned operand with new shape and operands work. 
- if (root.opcode() != HloOpcode::kConstant) { - std::vector params; - for (int64 o = 0; o < root.operands().size(); o++) { - params.push_back(map->operands()[root.operand(o)->parameter_number()]); - } - HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), params)); + if (root.opcode() == HloOpcode::kParameter) { + // If the root is a parameter, then use the corresponding operand as the + // result of the computation. TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } else { + map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + } else if (root.opcode() == HloOpcode::kConstant) { + // If the input is a constant then the shape of the constant could be + // different than the map shape. Hence, a broadcast is needed, else the + // cloned operand with new shape and operands work. + // // The constant is in an embedded computation and needs to be recreated // as part of the computation that the broadcast is inserted into. 
HloInstruction* constant = computation_->AddInstruction(root.Clone()); @@ -98,6 +94,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { HloInstruction::CreateBroadcast(map->shape(), constant, {})); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); + } else { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } + HloInstruction* placed_instruction = computation_->AddInstruction( + root.CloneWithNewOperands(map->shape(), params)); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); } changed_ = true; return Status::OK(); @@ -106,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { return Status::OK(); } -StatusOr Inliner::Run(HloModule* module) { - InlinerVisitor visitor(/*computation=*/nullptr); +StatusOr MapInliner::Run(HloModule* module) { + MapInlinerVisitor visitor(/*computation=*/nullptr); bool changed = false; for (HloComputation* computation : module->computations()) { TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/map_inliner.h similarity index 59% rename from tensorflow/compiler/xla/service/inliner.h rename to tensorflow/compiler/xla/service/map_inliner.h index efa8ed3abcc6cd7cd8d31ec2170eae8752988c09..b67911811846e2250068921ef252b7df596d4016 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/map_inliner.h @@ -13,27 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { -// A pass which performs inlining. Which can result, for example, in functions -// that were previously being mapped by Map instead directly applied to the -// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloPassInterface { +// A pass which performs map inlining. This replaces kMap instructions with +// their equivalent sequence of array operations. For example: +// map({X, Y}, add) -> add(X, Y)). +class MapInliner : public HloModulePass { public: - ~Inliner() override = default; - absl::string_view name() const override { return "inline"; } + ~MapInliner() override = default; + absl::string_view name() const override { return "map-inline"; } - // Run inlining on the given computation. Returns whether the computation was - // changed. + // Run map inlining on the given computation. Returns whether the computation + // was changed. 
StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc similarity index 71% rename from tensorflow/compiler/xla/service/inliner_test.cc rename to tensorflow/compiler/xla/service/map_inliner_test.cc index 5695bc242057c037a1999e7d63f5b4f21b5f658a..84059dd0f71ee8fc0a25703cbab2268d7dc149a8 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include #include @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloTestBase; +using MapInlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` -TEST_F(InlinerTest, MapMax) { +TEST_F(MapInlinerTest, MapMax) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto max_builder = HloComputation::Builder(TestName()); @@ -63,19 +63,19 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } // Test that `constant` function is changed to `broadcast`. 
-TEST_F(InlinerTest, MapConstant) { +TEST_F(MapInlinerTest, MapConstant) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto const2_builder = HloComputation::Builder(TestName()); @@ -97,18 +97,18 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); - Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_F(InlinerTest, MapSubtractOppositeOrder) { +TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); // Note that the parameter ordinals are in the opposite order to their @@ -135,17 +135,47 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); // Verify execution on CPU. 
- auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } +TEST_F(MapInlinerTest, MapParameter) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto param_builder = HloComputation::Builder(TestName()); + param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); + param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); + auto param_f32 = param_builder.Build(); + + auto builder = HloComputation::Builder("MapParamFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEmbeddedComputation(std::move(param_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR0(4); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..8269842426e3ee15ea974098a43fe7752c7614df --- /dev/null +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "absl/types/variant.h" +namespace xla { + +se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() { + if (HasOwnership()) { + return absl::get(mem_).AsDeviceMemoryBase(); + } else { + return absl::get(mem_); + } +} + +bool MaybeOwningDeviceMemory::HasOwnership() const { + return absl::holds_alternative(mem_); +} + +absl::optional MaybeOwningDeviceMemory::Release() { + if (!HasOwnership()) { + return {}; + } + OwningDeviceMemory result = std::move(absl::get(mem_)); + mem_ = result.AsDeviceMemoryBase(); + return absl::make_optional(std::move(result)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..82e7f1183c086437e10daea85ea99235b06cbb35 --- /dev/null +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ + +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +namespace xla { + +// MaybeOwningDeviceMemory represents either an owned or unowned device memory. +// Like std::variant. When the object goes +// output of scope, it will free the underlying memory if it owns it. +class MaybeOwningDeviceMemory { + public: + MaybeOwningDeviceMemory() = default; + explicit MaybeOwningDeviceMemory(OwningDeviceMemory owned) + : mem_(std::move(owned)) {} + explicit MaybeOwningDeviceMemory(se::DeviceMemoryBase unowned) + : mem_(unowned) {} + MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default; + ~MaybeOwningDeviceMemory() = default; + + MaybeOwningDeviceMemory& operator=(se::DeviceMemoryBase unowned) { + mem_ = unowned; + return *this; + } + + MaybeOwningDeviceMemory& operator=(OwningDeviceMemory owned) { + mem_ = std::move(owned); + return *this; + } + + MaybeOwningDeviceMemory& operator=(MaybeOwningDeviceMemory&&) = default; + + // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The + // caller of this function is *not* responsible for freeing the memory. 
+ se::DeviceMemoryBase AsDeviceMemoryBase(); + + // Release the OwningDeviceMemory without freeing it, and moves the ownership + // of the memory buffer from the object to the caller. + // + // A nullopt is returned if the HasOwnership() == false; + absl::optional Release(); + + // Returns true if the device_memory has ownership over underlying memory. + bool HasOwnership() const; + + private: + absl::variant mem_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 4166ef5baf9c891968b584a0c498005e9ae87784..2ca527bc4cb8f66a085c1e6a7cbb8ddaedbfc07e 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -50,7 +50,7 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { all_fusion_candidates_.push_back(instruction); std::vector candidates; - tensorflow::gtl::FlatSet candidates_set; + absl::flat_hash_set candidates_set; VLOG(10) << "Looking at instruction: " << instruction->name(); for (auto operand : instruction->operands()) { // Filter out the non-interesting instructions -- they @@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { // Update the fusible list for fusion. Variable new_fusibles keeps // track of the new or changed entries. 
std::vector> new_fusibles; - tensorflow::gtl::FlatSet in_list; + absl::flat_hash_set in_list; auto it = fusion_node.fusibles.begin(); while (it != fusion_node.fusibles.end()) { HloInstruction* instr = it->first; @@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() { void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip) { for (auto instr : instrs_to_update) { if (skip != nullptr && skip(instr)) { diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 4c8cb7d379d4f82224ef5896fbd937d4aa482606..9508ab2ed1d38ec40983d8892ec8875b848fb21b 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -44,7 +45,7 @@ namespace xla { // Note that the reachability map is updated based on the original computation. // This works because the reachability is monotonically increasing with // instruction fusion. -class MultiOutputFusion : public HloPassInterface { +class MultiOutputFusion : public HloModulePass { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} @@ -92,7 +93,7 @@ class MultiOutputFusion : public HloPassInterface { // Update the reachability map after fusing instr1 and instr2. void UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip = nullptr); // Hook for multi-output fusion along producer-consumer edges. 
@@ -126,7 +127,7 @@ class MultiOutputFusion : public HloPassInterface { std::vector candidates_; // A map that maps an instruction to the index_. - tensorflow::gtl::FlatMap candidates_index_; + absl::flat_hash_map candidates_index_; // The reachability map of current computation. std::unique_ptr reachability_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 70cd0a339a4da54ede7b709a1ce5de254b530577..ac2f79674feceff436c0e9c65338967f498e4473 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) { } /*static*/ string NameUniquer::GetSanitizedName(const string& name) { + if (name.empty()) { + return ""; + } string result = name; - CHECK(!result.empty()) << "name should not be empty"; char c = static_cast(result[0]); if (!isalpha(c) && c != '_') { result[0] = '_'; @@ -54,7 +56,7 @@ NameUniquer::NameUniquer(const string& separator) { } string NameUniquer::GetUniqueName(absl::string_view prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); + string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 6dd89c240f81c9f0ccac66e50c7f244bfd5429f1..8909d0f4fea801e43ab06a75e8933d24a74146bc 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,10 +18,10 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -69,7 +69,7 @@ class NameUniquer { int64 next_ = 0; // Set of all the identifiers which has been used. - tensorflow::gtl::FlatSet used_; + absl::flat_hash_set used_; }; // The string to use to separate the prefix of the name from the uniquing @@ -78,7 +78,7 @@ class NameUniquer { // Map from name prefix to the generator data structure which tracks used // identifiers and generates new ones. - tensorflow::gtl::FlatMap generated_names_; + absl::flat_hash_map generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ccc06ce613cb133d0be982bbb58bbc64d42a27c1..380cde0e6a858c7800445be94bb08dc22f3e776a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -17,8 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #include "absl/strings/string_view.h" +#include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -116,15 +120,82 @@ namespace xla { // .WithOperand(1, Op(&c)) // .WithOperand(2, Op(&d)) // + +struct MatchOption { + // If true, actually capture matched item into the user pointer. 
+ bool capture; +}; + template -bool Match(Value* value, const Pattern& pattern) { - return pattern.Match(value); +bool Match(Value* value, const Pattern& pattern, + MatchOption option = {/*.capture=*/true}) { + if (option.capture) { + auto new_option = option; + new_option.capture = false; + if (!pattern.Match(value, new_option)) { + return false; + } + } + return pattern.Match(value, option); } namespace match { namespace detail { +template +class AllOfPattern { + public: + explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + bool Match(Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + private: + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return std::get(patterns_).Match(item, option) && + MatchImpl(item, option, std::integral_constant()); + } + + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return true; + } + + std::tuple patterns_; +}; + +} // namespace detail + +// Returns a pattern that represents the conjunction of all input patterns. All +// patterns need to match in order to have the AllOf pattern match. +// +// TODO(timshen): Currently AllOf is still nested, e.g. AllOf, B> is +// not AllOf. We might want to flatten the AllOf type structure if the +// C++ compile error message gets annoying. +template +detail::AllOfPattern::type, Patterns...> AllOf( + const Patterns&... 
patterns) { + return detail::AllOfPattern::type, + Patterns...>(patterns...); +} + +namespace detail { + template class LayoutPattern; @@ -132,57 +203,61 @@ class LayoutPattern; // nullptr. class LayoutPatternBaseImpl { public: - bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout != nullptr; + } }; // A LayoutPattern implementation that matches only if the layout equals a // Layout proto. -template class LayoutPatternEqualImpl { public: - explicit constexpr LayoutPatternEqualImpl(const Previous& previous, - const ::xla::Layout* layout) - : previous_(previous), layout_(layout) {} + explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout) + : layout_(layout) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return LayoutUtil::Equal(*layout_, *layout); } private: - Previous previous_; const ::xla::Layout* layout_; }; // A LayoutPattern implementation that matches only if the layout has a given // format. -template class LayoutPatternFormatImpl { public: - explicit constexpr LayoutPatternFormatImpl(const Previous& previous, - Format format) - : previous_(previous), format_(format) {} + explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && layout->format() == format_; + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout->format() == format_; } private: - Previous previous_; Format format_; }; // A pattern that matches Layouts. 
template class LayoutPattern { + private: + template + LayoutPattern> + AppendImpl(NewImpl new_impl) const { + return LayoutPattern>( + AllOf(impl_, std::move(new_impl)), matched_layout_); + } + public: explicit constexpr LayoutPattern(const Impl& impl, LayoutType** matched_layout) : impl_(impl), matched_layout_(matched_layout) {} // Returns true and captures the layout iff it matches the pattern. - bool Match(const ::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(const ::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -191,9 +266,9 @@ class LayoutPattern { } // Returns true and captures the layout iff it matches the pattern. - bool Match(::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -203,24 +278,21 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. - constexpr LayoutPattern> EqualTo( - const ::xla::Layout* layout) const { - return LayoutPattern>( - LayoutPatternEqualImpl(impl_, layout), matched_layout_); + constexpr auto EqualTo(const ::xla::Layout* layout) const + -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { + return AppendImpl(LayoutPatternEqualImpl(layout)); } // Modifies the pattern to match only if the layout has a dense format. 
- constexpr LayoutPattern> - WithDenseFormat() const { - return LayoutPattern>( - LayoutPatternFormatImpl(impl_, DENSE), matched_layout_); + constexpr auto WithDenseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { + return AppendImpl(LayoutPatternFormatImpl(DENSE)); } // Modifies the pattern to match only if the layout has a sparse format. - constexpr LayoutPattern> - WithSparseFormat() const { - return LayoutPattern>( - LayoutPatternFormatImpl(impl_, SPARSE), matched_layout_); + constexpr auto WithSparseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) { + return AppendImpl(LayoutPatternFormatImpl(SPARSE)); } private: @@ -228,8 +300,72 @@ class LayoutPattern { LayoutType** matched_layout_; }; +template +class AnyOfPattern { + public: + explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant()); + } + + bool Match(Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant()); + } + + private: + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + auto new_option = option; + new_option.capture = false; + // Try to match the sub-pattern without capturing behavior. + if (std::get(patterns_).Match(item, new_option)) { + // Capture the branch. + if (option.capture) { + // TODO(timshen): Currently the behavior can be exponential. Optimize it + // with memoization or recording the matched sub-pattern index, if it + // takes too long to run. + // + // Specifically, the "memoization" approach is to create an empty + // container with the key (pattern, instruction), and value as whether + // matched or not. + // + // Alternatively, we may run the pattern matching with captures off, but + // instead record a "trace" somewhere, indicating how exactly the + // pattern matches the input. 
For example, the trace information for + // AnyOf will be a runtime number indicate which sub-pattern is matched. + // Then we run another pass to do captures only with the help of the + // trace. + bool ret = std::get(patterns_).Match(item, option); + DCHECK(ret); + } + return true; + } + return MatchImpl(item, option, std::integral_constant()); + } + + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return false; + } + + std::tuple patterns_; +}; + } // namespace detail +// Returns a pattern that represents the logical disjunction of the input +// patterns. The returned pattern matches from left to right, and stops on the +// first match. +template +detail::AnyOfPattern::type, Patterns...> AnyOf( + const Patterns&... patterns) { + return detail::AnyOfPattern::type, + Patterns...>(patterns...); +} + // Creates a layout pattern that will capture the matched layout in the // argument. inline constexpr detail::LayoutPattern class ShapePatternEqualImpl { public: - explicit constexpr ShapePatternEqualImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Equal(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape is compatible to // a Shape proto. 
-template class ShapePatternCompatibleImpl { public: - explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Compatible(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape has a given // element type. -template class ShapePatternElementTypeImpl { public: - explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, - PrimitiveType element_type) - : previous_(previous), element_type_(element_type) {} + explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type) + : element_type_(element_type) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && shape->element_type() == element_type_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return shape->element_type() == element_type_; } private: - Previous previous_; PrimitiveType element_type_; }; // A ShapePattern implementation that matches only if the shape is scalar. 
-template class ShapePatternIsScalarImpl { public: - explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsScalarImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsScalar(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is an array -template class ShapePatternIsArrayImpl { public: - explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsArrayImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsArray(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is a tuple. -template class ShapePatternIsTupleImpl { public: - explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsTupleImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsTuple(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape has a given // rank. 
-template class ShapePatternRankImpl { public: - explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) - : previous_(previous), rank_(rank) {} + explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Rank(*shape) == rank_; } private: - Previous previous_; int64 rank_; }; // A ShapePattern implementation that matches only if the shape has a layout // that matches a given pattern. -template +template class ShapePatternLayoutImpl { public: explicit constexpr ShapePatternLayoutImpl( - const Previous& previous, const LayoutPattern& layout) - : previous_(previous), layout_(layout) {} + : layout_(layout) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(&shape->layout()); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout(), option); } - bool Match(Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout()); + bool Match(Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout(), option); } private: - Previous previous_; LayoutPattern layout_; }; // A ShapePattern implementation that matches only if the shape has a subshape // that matches a given pattern. 
-template +template class ShapePatternSubshapeImpl { public: explicit ShapePatternSubshapeImpl( - const Previous& previous, ShapeIndexView index, + ShapeIndexView index, const ShapePattern& subshape) - : previous_(previous), index_(index), subshape_(subshape) {} + : index_(index), subshape_(subshape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); } - bool Match(::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + bool Match(::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), + option); } private: - Previous previous_; ShapeIndexView index_; ShapePattern subshape_; }; @@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl { // A pattern that matches Shapes. template class ShapePattern { + private: + template + ShapePattern> AppendImpl( + NewImpl new_impl) const { + return ShapePattern>( + AllOf(impl_, std::move(new_impl)), matched_shape_); + } + public: explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) : impl_(impl), matched_shape_(matched_shape) {} // Returns true and captures the shape iff it matches the pattern. 
- bool Match(const ::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -447,9 +564,9 @@ class ShapePattern { } // Returns true and captures the shape iff it matches the pattern. - bool Match(::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -459,108 +576,90 @@ class ShapePattern { // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. - constexpr ShapePattern> EqualTo( - const ::xla::Shape* shape) const { - return ShapePattern>( - ShapePatternEqualImpl(impl_, shape), matched_shape_); + constexpr auto EqualTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { + return AppendImpl(ShapePatternEqualImpl(shape)); } // Modifies the pattern to match only if the shape is compatible to the given // proto. The layout must outlive the returned pattern. - constexpr ShapePattern> - CompatibleTo(const ::xla::Shape* shape) const { - return ShapePattern>( - ShapePatternCompatibleImpl(impl_, shape), matched_shape_); + constexpr auto CompatibleTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { + return AppendImpl(ShapePatternCompatibleImpl(shape)); } // Modifies the pattern to match only if the shape has the given element type. 
- constexpr ShapePattern> - WithElementType(PrimitiveType element_type) const { - return ShapePattern>( - ShapePatternElementTypeImpl(impl_, element_type), matched_shape_); + constexpr auto WithElementType(PrimitiveType element_type) const + -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { + return AppendImpl(ShapePatternElementTypeImpl(element_type)); } // Modifies the pattern to match only if the shape is scalar. - constexpr ShapePattern> IsScalar() - const { - return ShapePattern>( - ShapePatternIsScalarImpl(impl_), matched_shape_); + constexpr auto IsScalar() const + -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { + return AppendImpl(ShapePatternIsScalarImpl()); } // Modifies the pattern to match only if the shape is an array. - constexpr ShapePattern> IsArray() - const { - return ShapePattern>( - ShapePatternIsArrayImpl(impl_), matched_shape_); + constexpr auto IsArray() const + -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { + return AppendImpl(ShapePatternIsArrayImpl()); } // Modifies the pattern to match only if the shape is a tuple. - constexpr ShapePattern> IsTuple() - const { - return ShapePattern>( - ShapePatternIsTupleImpl(impl_), matched_shape_); + constexpr auto IsTuple() const + -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { + return AppendImpl(ShapePatternIsTupleImpl()); } // Modifies the pattern to match only if the shape has the given rank. - constexpr ShapePattern> WithRank( - int64 rank) const { - return ShapePattern>( - ShapePatternRankImpl(impl_, rank), matched_shape_); + constexpr auto WithRank(int64 rank) const + -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { + return AppendImpl(ShapePatternRankImpl(rank)); } // Modifies the pattern to match only if the shape has a layout that matches // the given pattern. 
template - constexpr ShapePattern> - WithLayout(const LayoutPattern& layout) const { - return ShapePattern>( - ShapePatternLayoutImpl(impl_, layout), - matched_shape_); - } - - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - WithLayoutEqualTo(const ::xla::Layout* layout) const { + auto WithLayout(const LayoutPattern& layout) const + -> decltype(this->AppendImpl( + ShapePatternLayoutImpl(layout))) { + return AppendImpl(ShapePatternLayoutImpl(layout)); + } + + constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const + -> decltype(this->WithLayout(Layout().EqualTo(layout))) { return WithLayout(Layout().EqualTo(layout)); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - IsDenseArray() const { + constexpr auto IsDenseArray() const + -> decltype(this->WithLayout(Layout().WithDenseFormat())) { return WithLayout(Layout().WithDenseFormat()); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - IsSparseArray() const { + constexpr auto IsSparseArray() const + -> decltype(this->WithLayout(Layout().WithSparseFormat())) { return WithLayout(Layout().WithSparseFormat()); } // Modifies the pattern to match only if the shape has a subshape that matches // the given pattern. 
template + auto WithSubshape(ShapeIndexView index, + const ShapePattern& subshape) + const -> decltype(this->AppendImpl( + ShapePatternSubshapeImpl(index, + subshape))) { + return AppendImpl( + ShapePatternSubshapeImpl(index, subshape)); + } + ShapePattern> - WithSubshape(ShapeIndexView index, - const ShapePattern& subshape) const { - return ShapePattern< - ShapeType, ShapePatternSubshapeImpl>( - ShapePatternSubshapeImpl(impl_, index, - subshape), - matched_shape_); - } - - ShapePattern>> + AllOfPattern>>> WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, ShapePattern( @@ -568,9 +667,12 @@ class ShapePattern { .EqualTo(shape)); } - ShapePattern>> + ShapePattern>>> WithSubshapeCompatibleTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, @@ -611,159 +713,169 @@ class HloInstructionPattern; // instruction is not nullptr. class HloInstructionPatternBaseImpl { public: - bool Match(const ::xla::HloInstruction* inst) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { return inst != nullptr; } }; // An HloInstructionPattern implementation that matches only if the instruction // has a given name. -template class HloInstructionPatternNameImpl { public: - explicit HloInstructionPatternNameImpl(const Previous& previous, - absl::string_view name) - : previous_(previous), name_(name) {} + explicit HloInstructionPatternNameImpl(absl::string_view name) + : name_(name) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->name() == name_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->name() == name_; } private: - Previous previous_; absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. 
-template class HloInstructionPatternOpcodeImpl { public: - explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, - HloOpcode opcode, + explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode, bool invert) - : previous_(previous), opcode_(opcode), invert_(invert) {} + : opcode_(opcode), invert_(invert) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return (invert_ ^ (inst->opcode() == opcode_)); } private: - Previous previous_; HloOpcode opcode_; bool invert_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a shape that matches a given pattern. -template +template class HloInstructionPatternShapeImpl { public: explicit constexpr HloInstructionPatternShapeImpl( - const Previous& previous, const ShapePattern& shape) - : previous_(previous), shape_(shape) {} + const ShapePattern& shape) + : shape_(shape) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(&inst->shape()); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(&inst->shape(), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(inst->mutable_shape(), option); } private: - Previous previous_; ShapePattern shape_; }; // An HloInstructionPattern implementation that matches only if the instruction // has an operand that matches a given pattern. 
-template +template class HloInstructionPatternOperandImpl { public: explicit constexpr HloInstructionPatternOperandImpl( - const Previous& previous, int64 operand_index, + int64 operand_index, const HloInstructionPattern& operand) - : previous_(previous), operand_index_(operand_index), operand_(operand) {} + : operand_index_(operand_index), operand_(operand) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->mutable_operand(operand_index_)); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_), option); } private: - Previous previous_; int64 operand_index_; HloInstructionPattern operand_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. 
-template class HloInstructionPatternFusionKindImpl { public: explicit constexpr HloInstructionPatternFusionKindImpl( - const Previous& previous, ::xla::HloInstruction::FusionKind kind) - : previous_(previous), kind_(kind) {} + ::xla::HloInstruction::FusionKind kind) + : kind_(kind) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } private: - Previous previous_; ::xla::HloInstruction::FusionKind kind_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a kGetTupleElement with a particular tuple index. 
-template class HloInstructionPatternTupleIndexImpl { public: - explicit constexpr HloInstructionPatternTupleIndexImpl( - const Previous& previous, int64 tuple_index) - : previous_(previous), tuple_index_(tuple_index) {} + explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index) + : tuple_index_(tuple_index) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } private: - Previous previous_; int64 tuple_index_; }; +template +class HloPredicatePatternImpl { + public: + explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + + bool Match(const ItemType* item, MatchOption option) const { + return pred_(item); + } + + bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + + private: + Predicate pred_; +}; + +struct PatternFriend; + // A pattern that matches HloInstructions. template class HloInstructionPattern { + private: + template + HloInstructionPattern> + AppendImpl(NewImpl new_impl) const { + return HloInstructionPattern< + HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( + AllOf(impl_, std::move(new_impl)), matched_inst_); + } + public: explicit constexpr HloInstructionPattern(const Impl& impl, HloInstructionType** matched_inst) : impl_(impl), matched_inst_(matched_inst) {} // Returns true and captures the instruction iff it matches the pattern. 
- bool Match(const ::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -772,9 +884,9 @@ class HloInstructionPattern { } // Returns true and captures the instruction iff it matches the pattern. - bool Match(::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -783,102 +895,87 @@ class HloInstructionPattern { } // Modifies the pattern to match only if the instruction has the given name. - HloInstructionPattern> - WithName(absl::string_view name) const { - return HloInstructionPattern>( - HloInstructionPatternNameImpl(impl_, name), matched_inst_); + auto WithName(absl::string_view name) const + -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { + return AppendImpl(HloInstructionPatternNameImpl(name)); } // Modifies the pattern to match only if the instruction has the given opcode. - constexpr HloInstructionPattern> - WithOpcode(HloOpcode opcode) const { - return HloInstructionPattern>( - HloInstructionPatternOpcodeImpl(impl_, opcode, false), - matched_inst_); + auto WithOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + false))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } // Modifies the pattern to match only if the instruction does not have the // given opcode. 
- constexpr HloInstructionPattern> - WithoutOpcode(HloOpcode opcode) const { - return HloInstructionPattern>( - HloInstructionPatternOpcodeImpl(impl_, opcode, true), - matched_inst_); + auto WithoutOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + true))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } // Modifies the pattern to match only if the instruction is a constant. - constexpr HloInstructionPattern> - IsConstant() const { + constexpr auto IsConstant() const + -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction is not a constant. - constexpr HloInstructionPattern> - IsNonConstant() const { + constexpr auto IsNonConstant() const + -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { return WithoutOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction has a shape that // matches the given pattern. template - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl> - WithShape(const ShapePattern& shape) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl>( - HloInstructionPatternShapeImpl(impl_, - shape), - matched_inst_); + constexpr auto WithShape(const ShapePattern& shape) + const -> decltype(this->AppendImpl( + HloInstructionPatternShapeImpl(shape))) { + return AppendImpl( + HloInstructionPatternShapeImpl(shape)); } // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. 
template - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl> - WithOperand( + constexpr auto WithOperand( int64 operand_index, - const HloInstructionPattern& operand) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl>( - HloInstructionPatternOperandImpl( - impl_, operand_index, operand), - matched_inst_); + const HloInstructionPattern& operand) const + -> decltype(this->AppendImpl( + HloInstructionPatternOperandImpl( + operand_index, operand))) { + return AppendImpl( + HloInstructionPatternOperandImpl( + operand_index, operand)); } // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. - constexpr HloInstructionPattern> - WithFusionKind(HloInstruction::FusionKind kind) const { - return HloInstructionPattern>( - HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const + -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { + return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); } // Modifies the pattern to match only if the instruction is a // get-tuple-element with the given tuple index. 
- constexpr HloInstructionPattern> - WithTupleIndex(int64 tuple_index) const { - return HloInstructionPattern>( - HloInstructionPatternTupleIndexImpl(impl_, tuple_index), - matched_inst_); + constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( + this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { + return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } private: + template + constexpr auto WithPredicate(Predicate pred) const -> decltype( + this->AppendImpl(HloPredicatePatternImpl( + std::move(pred)))) { + return AppendImpl( + HloPredicatePatternImpl(std::move(pred))); + } + + friend struct PatternFriend; + Impl impl_; HloInstructionType** matched_inst_; }; @@ -918,6 +1015,7 @@ Op(::xla::HloInstruction** matched_inst) { } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Iota) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. @@ -1004,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose) .WithOperand(0, std::forward(lhs)) \ .WithOperand(1, std::forward(rhs)); \ } -XLA_BINOP_PATTERN(Add) + +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(AnyOf(NAME(lhs, rhs), NAME(rhs, lhs))) { \ + return AnyOf(NAME(lhs, rhs), NAME(rhs, lhs)); \ + } \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(AnyOf(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs))) { \ + return AnyOf(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs)); \ + } +XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(Eq) +XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) XLA_BINOP_PATTERN(Gt) XLA_BINOP_PATTERN(Le) XLA_BINOP_PATTERN(Lt) -XLA_BINOP_PATTERN(Maximum) -XLA_BINOP_PATTERN(Minimum) 
-XLA_BINOP_PATTERN(Multiply) -XLA_BINOP_PATTERN(Ne) +XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) +XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) +XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) +XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) -XLA_BINOP_PATTERN(And) -XLA_BINOP_PATTERN(Or) +XLA_COMMUTATIVE_BINOP_PATTERN(And) +XLA_COMMUTATIVE_BINOP_PATTERN(Or) XLA_BINOP_PATTERN(ShiftLeft) XLA_BINOP_PATTERN(ShiftRightArithmetic) XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_COMMUTATIVE_BINOP_PATTERN #undef XLA_BINOP_PATTERN // Helpers for ternary instructions. @@ -1069,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN +namespace detail { +struct PatternFriend { + template + static auto ConstantScalar(T constant) -> decltype( + Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate( + std::declval>())) { + std::function pred = + [constant](const HloInstruction* instr) { + const auto& literal = Cast(instr)->literal(); + auto status_or_const = LiteralUtil::CreateR0(constant).Convert( + literal.shape().element_type()); + return status_or_const.ok() && + literal == status_or_const.ConsumeValueOrDie(); + }; + + return Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate(std::move(pred)); + } +}; +} // namespace detail + // Helpers for matching non-constant instructions. 
inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); @@ -1106,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } +template +inline auto ConstantScalar(T constant) + -> decltype(detail::PatternFriend::ConstantScalar(constant)) { + return detail::PatternFriend::ConstantScalar(constant); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index a530581c34bf1d699eae3c53203c197f7943cc53..3ab7b7fd7168d7ddd1470fdb03a04ba7b171fddb 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) { EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); } +TEST(PatternMatcherTest, AnyOf) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE( + Match(root, match::AnyOf(match::ConstantScalar(0), + match::ConstantScalar(1)))); + EXPECT_TRUE( + Match(root, match::AnyOf(match::ConstantScalar(1), + match::ConstantScalar(0)))); + EXPECT_FALSE( + Match(root, match::AnyOf(match::ConstantScalar(0), + match::ConstantScalar(2)))); +} + +TEST(PatternMatcherTest, ConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE(Match(root, match::ConstantScalar(42))); + EXPECT_FALSE(Match(root, match::ConstantScalar(41))); + EXPECT_FALSE(Match(root, match::ConstantScalar(0))); +} + 
+TEST(PatternMatcherTest, NoMatchConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_FALSE(Match(root, match::ConstantScalar(42))); +} + +TEST(PatternMatcherTest, MultiplyAnyOrder) { + using match::ConstantScalar; + using match::MultiplyAnyOrder; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + const HloInstruction* instr; + + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); +} + +TEST(PatternMatcherTest, AnyOfShortCircuit) { + using match::AnyOf; + using match::Multiply; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf(Multiply(&mul, Op(), Op()), Op(&any)))); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(nullptr, any); + } + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf(Op(&any), Multiply(&mul, Op(), Op())))); + EXPECT_NE(nullptr, any); + EXPECT_EQ(nullptr, mul); + } +} + +TEST(PatternMatcherTest, AllOf) { + using match::AllOf; + using match::Broadcast; + using 
match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); + auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); + ASSERT_TRUE(Match(root, scalar_pattern)); + ASSERT_TRUE(Match(root, f16_pattern)); + EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern))); + EXPECT_TRUE(Match(root, AllOf(f16_pattern, scalar_pattern))); + EXPECT_FALSE( + Match(root, AllOf(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE( + Match(root, AllOf(Broadcast(Op()), scalar_pattern))); +} + +TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_FALSE( + Match(root, AllOf(Constant(&constant), Broadcast(Op())))); + EXPECT_EQ(nullptr, constant); + ASSERT_TRUE(Match(root, Constant(&constant))); + EXPECT_NE(nullptr, constant); +} + +TEST(PatternMatcherTest, TestNoCapture) { + using match::Constant; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false})); + EXPECT_EQ(nullptr, constant); +} + +TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { + using 
match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + u = f16[] parameter(0) + v = f16[] parameter(1) + ROOT add = f16[] add(u, v) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* addend0 = nullptr; + const HloInstruction* addend1 = nullptr; + const HloInstruction* addend2 = nullptr; + auto add2_pattern = Add(Op(&addend0), Op(&addend1)); + auto add3_pattern = AnyOf( + AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0)); + + ASSERT_TRUE(Match(root, add3_pattern)); + EXPECT_NE(nullptr, addend0); + EXPECT_NE(nullptr, addend1); + EXPECT_EQ(nullptr, addend2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 150af0cd9323479d2e7af1133184349e7bccd393..c522e7ae23b734090f85d241bf365fccc37f0adb 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -89,7 +90,11 @@ PlatformUtil::GetSupportedPlatforms() { if (platforms.empty()) { return NotFound("no platforms found"); } else if (platforms.size() == 1) { - return platforms[0]; + se::Platform* platform = platforms[0]; + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; } // Multiple platforms present and we can't pick a reasonable default. 
@@ -98,23 +103,32 @@ PlatformUtil::GetSupportedPlatforms() { [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", - platforms_string.c_str()); + platforms_string); } /* static */ StatusOr PlatformUtil::GetDefaultPlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + + se::Platform* platform = nullptr; if (platforms.empty()) { return NotFound("no platforms found"); } else if (platforms.size() == 1) { - return platforms[0]; + platform = platforms[0]; } else if (platforms.size() == 2) { for (int i = 0; i < 2; i++) { if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter && absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) { - return platforms[1 - i]; + platform = platforms[1 - i]; + break; } } } + if (platform != nullptr) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; + } // Multiple platforms present and we can't pick a reasonable default. 
string platforms_string = absl::StrJoin( @@ -123,7 +137,7 @@ PlatformUtil::GetSupportedPlatforms() { return InvalidArgument( "must specify platform because more than one platform (except for the " "interpreter platform) found: %s", - platforms_string.c_str()); + platforms_string); } /*static*/ StatusOr PlatformUtil::GetPlatform( @@ -132,10 +146,13 @@ PlatformUtil::GetSupportedPlatforms() { TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { if (absl::AsciiStrToLower(platform->Name()) == platform_str) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } return platform; } } - return InvalidArgument("platform %s not found", platform_name.c_str()); + return InvalidArgument("platform %s not found", platform_name); } /*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( @@ -151,17 +168,21 @@ PlatformUtil::GetSupportedPlatforms() { } if (matched.empty()) { return InvalidArgument("unable to find platform that is not %s", - platform_name.c_str()); + platform_name); } if (matched.size() == 1) { - return matched[0]; + auto platform = matched[0]; + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; } string matched_string = absl::StrJoin( matched, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "found multiple platforms %s, but expected one platform except for %s", - matched_string.c_str(), platform_name.c_str()); + matched_string, platform_name); } // Returns whether the device underlying the given StreamExecutor is supported @@ -192,14 +213,17 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { PlatformUtil::GetStreamExecutors(se::Platform* platform) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { - return NotFound("no %s devices found", platform->Name().c_str()); + return NotFound("no %s devices found", platform->Name()); } if 
(platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware // thread. Because we parallelize a single computation across threads, it - // doesn't make sense to expose these as separate devices, so fix the number - // of devices to one. - device_count = 1; + // doesn't make sense to expose these as separate devices, so by default we + // fix the number of devices to one. However we do let the user override + // this behavior to help run tests on the host that run models in parallel + // across multiple devices. + device_count = legacy_flags::GetDebugOptionsFromFlags() + .xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; @@ -231,7 +255,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (std::all_of(stream_executors.begin(), stream_executors.end(), [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", - platform->Name().c_str()); + platform->Name()); } return stream_executors; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 256b231e3af43a2ee85c97a5efab1f022d4de4b1..0b4e82e8d606cf2cacfab42d07c2201939d5e10b 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -22,14 +22,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { // HLO pass which inserts reduce-precision instructions into the HLO graph, for // purposes of experimenting with the effects of reduced-precision storage of // intermediate values. 
-class ReducePrecisionInsertion : public HloPassInterface { +class ReducePrecisionInsertion : public HloModulePass { using InstructionFilterFunction = std::function; public: diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1e86a0823a56a9e52421a5c8bd49e0adb98a2c70..a3db439e34000ef3fcf4b190cb372947e285a64e 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -24,7 +24,7 @@ namespace xla { // This now only moves them outputward across elementwise ops all whose operands // are equivalent Reshapes or Transposes, but in future could potentially move // them inputward also. -class ReshapeMover : public HloPassInterface { +class ReshapeMover : public HloModulePass { public: absl::string_view name() const override { return "reshape-mover"; } diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index a395dd5333f9b6b5f71a561b52cd9312a3faef2d..fcf269eee925c2ddb7511d70e71bd815e4b8c24a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -34,12 +34,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ReshapeMoverTest : public HloVerifiedTestBase { - public: - ReshapeMoverTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} -}; +class ReshapeMoverTest : public HloVerifiedTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 338f0c09e9e7f59127023144ff30ac62aff55ee1..de7aee262e61195b37099fc661a95508d0539e18 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -26,7 +26,6 @@ 
limitations under the License. namespace xla { -using tensorflow::gtl::ArraySlice; // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. @@ -87,7 +86,7 @@ static StatusOr CanonicalizeScatterIndices( // major dimensions and all the window dimensions appear in the minor // dimensions. static StatusOr PermuteScatterAndWindowDims( - HloInstruction* updates, ArraySlice update_window_dims) { + HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; const int64 updates_rank = ShapeUtil::Rank(updates->shape()); permutation.reserve(updates_rank); @@ -156,6 +155,53 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( return MakeConcatHlo(expanded_index_components, /*dimension=*/0); } +static StatusOr CheckIndexValidity( + HloComputation* computation, HloInstruction* index, + absl::Span operand_dims, absl::Span window_sizes, + HloModule* module) { + DCHECK_NE(nullptr, module); + DCHECK_EQ(operand_dims.size(), window_sizes.size()); + + // Valid range for the index: [0, operand_dims - window_sizes] + + // Check if the index has any negative values. + TF_ASSIGN_OR_RETURN( + HloInstruction * zero_index, + BroadcastZeros(computation, index->shape().element_type(), + AsInt64Slice(index->shape().dimensions()))); + TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, + MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + + // Check if the index is OOB w.r.t. the operand dimensions and window sizes. + std::vector max_valid_index(operand_dims.size()); + for (int i = 0; i < operand_dims.size(); ++i) { + max_valid_index[i] = operand_dims[i] - window_sizes[i]; + } + TF_ASSIGN_OR_RETURN( + HloInstruction * max_valid_index_constant, + MakeR1ConstantHlo(computation, index->shape().element_type(), + max_valid_index)); + TF_ASSIGN_OR_RETURN( + HloInstruction * oob_index_check, + MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + + // Combine the results of the two checks above. 
+ TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index, + MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check)); + + // Reduce the index validity check vector into a scalar predicate. + auto reduction_init = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index_reduced, + MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module)); + + // Return a broadcasted value of the scalar predicate to the same size as the + // window. + return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); +} + // Body of the while loop that performs the scatter operation using other HLOs. static StatusOr> ScatterLoopBody( HloInstruction* scatter, HloInstruction* induction_var, @@ -223,7 +269,16 @@ static StatusOr> ScatterLoopBody( InsertDegenerateDims(update_slice_for_scatter, AsInt64Slice(dim_numbers.inserted_window_dims()))); - // Extact the slice to update from `operand` tensor. + // Note that the following transformation assumes that both DynamicSlice and + // DynamicUpdateSlice follow the same semantics for OOB indices. For example, + // if there are negative indices and DynamicSlice uses "clamping" semantics, + // then the extracted data will be "shifted". Since DynamicUpdateSlice also + // follows the same "clamping" semantics, writing the update will also be + // "shifted" by exactly the same amount. So, this transformation is correct as + // long as the semantics of handling OOB indices remain the same in + // DynamicSlice and DynamicUpdateSlice. + + // Extract the slice to update from `operand` tensor. 
const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); TF_ASSIGN_OR_RETURN( HloInstruction * operand_slice_to_update, @@ -238,10 +293,24 @@ static StatusOr> ScatterLoopBody( MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, scatter->to_apply())); + TF_ASSIGN_OR_RETURN( + HloInstruction * is_index_valid, + CheckIndexValidity( + operand->parent(), scatter_slice_start, + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()), + scatter->GetModule())); + + // Select the updated operand only if the index is valid. If not, select the + // original value. + TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply, + MakeSelectHlo(is_index_valid, updated_operand_slice, + operand_slice_to_update)); + // Write the updated value of the slice into `operand` tensor. - TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, - MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, - scatter_slice_start)); + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start)); return StatusOr>{ {updated_operand, scatter_indices, updates}}; @@ -291,7 +360,7 @@ StatusOr ScatterExpander::ExpandScatter( return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " "supported. This error occurred for %s.", - scatter->ToString().c_str()); + scatter->ToString()); } // Canonicalize the scatter_indices, after which the size of its most-major diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 14f062c89cfd4657097c1a933621a3e945f89c53..559a85dccfef27816e7dbf746fd71c44bbf46f60 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -20,7 +20,7 @@ limitations under the License. 
namespace xla { -class ScatterExpander : public HloPassInterface { +class ScatterExpander : public HloModulePass { public: absl::string_view name() const override { return "scatter_expander"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index d39a5191b8e8fb9a420adfade73fbedea998d2bb..75465359f8f37e56369c0976ba7434e3c3f202cc 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" @@ -47,7 +48,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -55,24 +55,22 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/ptr_util.h" -using absl::StrCat; -using ::tensorflow::strings::Printf; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + // Records the arguments used to invoke a computation in an HloSnapshot proto. 
-Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::Stream* stream, TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordArguments(const absl::Span arguments, + se::Stream* stream, TransferManager* transfer_manager, + HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); - *module->add_arguments() = literal->ToProto(); + *module->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -82,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, result)); - *module->mutable_result() = literal->ToProto(); + *module->mutable_result() = literal.ToProto(); return Status::OK(); } @@ -148,19 +146,19 @@ Service::Service(const ServiceOptions& options, CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) << "Requested more replicas than there are devices."; } - LOG(INFO) << Printf( + LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. 
Devices:", this, - execute_backend_->platform()->Name().c_str()); + execute_backend_->platform()->Name()); for (int i = 0; i < execute_backend_->device_count(); ++i) { if (execute_backend_->device_ordinal_supported(i)) { se::StreamExecutor* executor = execute_backend_->stream_executor(i).ValueOrDie(); const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); } } } else { @@ -200,16 +198,16 @@ Status Service::ValidateResultShape(const Shape& client_shape, return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(client_shape).c_str(), - ShapeUtil::HumanString(result_shape).c_str()); + ShapeUtil::HumanStringWithLayout(client_shape), + ShapeUtil::HumanString(result_shape)); } return Status::OK(); } StatusOr>> Service::ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors) { + absl::Span arguments, + absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -231,9 +229,9 @@ Service::ResolveAndValidateArguments( return InvalidArgument( "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, shaped_buffer->platform()->Name().c_str(), + i, shaped_buffer->platform()->Name(), shaped_buffer->device_ordinal(), - execute_backend_->device_name(replica_device_ordinal).c_str()); + execute_backend_->device_name(replica_device_ordinal)); } 
replicated_arguments[replica].push_back(shaped_buffer); } @@ -243,13 +241,13 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options) { auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { - return InvalidArgument("computation takes %d parameters, but %zu given", + return InvalidArgument("computation takes %d parameters, but %u given", program_shape.parameters_size(), argument_shapes.size()); } @@ -261,8 +259,8 @@ StatusOr> Service::CreateModuleConfig( return InvalidArgument( "Argument does not match shape of computation parameter %d: want " "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + i, ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(*argument_shapes[i])); } TF_RETURN_IF_ERROR( computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -300,7 +298,7 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { @@ -314,7 +312,7 @@ StatusOr>> Service::BuildExecutables( std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); + VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. 
std::vector> hlo_snapshots; @@ -329,9 +327,8 @@ StatusOr>> Service::BuildExecutables( auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = - Printf("computation_%lld__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -344,19 +341,19 @@ StatusOr>> Service::BuildExecutables( } CHECK_EQ(module_protos.size(), module_configs.size()); - std::vector> modules; + auto module_group = + absl::make_unique(module_protos[0]->name()); for (int64 i = 0; i < module_protos.size(); ++i) { const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - HloModule::CreateFromProto(*proto, config)); - modules.push_back(std::move(module)); + TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + module_group->push_back(std::move(module)); } TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); + backend->compiler()->Compile(std::move(module_group), + std::move(executors), device_allocator)); for (size_t i = 0; i < module_protos.size(); ++i) { if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { @@ -369,12 +366,10 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile) { + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + 
absl::Span result_tags, ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector streams; @@ -454,8 +449,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < streams.size(); ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("failed to complete execution for stream %lld: %s", - i, block_status.error_message().c_str()); + return InternalError("failed to complete execution for stream %d: %s", i, + block_status.error_message()); } } @@ -513,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; @@ -557,8 +551,7 @@ StatusOr Service::ExecuteAndRegisterResult( // TODO(b/69985541): Support profiling also on this path. - std::vector> - replicated_arguments; + std::vector> replicated_arguments; for (const auto& arg : arguments) { replicated_arguments.push_back(arg); } @@ -580,7 +573,7 @@ StatusOr> Service::GetExecutors( if (requests_size > 1 && execution_options.device_handles_size() > 1) { return InvalidArgument( "Parallel requests with multiple device handles is not supported. " - "Found %lld parallel requests, with request %lld containing %d device " + "Found %d parallel requests, with request %d containing %d device " "handles.", requests_size, request_index, execution_options.device_handles_size()); } @@ -597,7 +590,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. 
// In the case of partitioned computations, assume all arguments go on the @@ -641,7 +634,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, arg->requests(i).execution_options(); const ExecuteGraphRequest& request = arg->requests(i); TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; - TF_RET_CHECK(request.computation().has_program_shape()) + TF_RET_CHECK(request.computation().has_host_program_shape()) << "programe shape may not be empty"; // Get the executors. @@ -658,7 +651,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(request.computation().program_shape(), + CreateModuleConfig(request.computation().host_program_shape(), replicated_arguments.front(), request.execution_options())); VLOG(3) @@ -745,8 +738,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%lld) exceeds the number of available devices " - "on the target (%lld)", + "Requested device count (%d) exceeds the number of available devices " + "on the target (%d)", arg->device_count(), available_device_count); } @@ -796,9 +789,9 @@ StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf( + VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, - module_proto.name().c_str()); + module_proto.name()); // Dump computation proto state if flag is set. 
auto hlo_snapshot = absl::make_unique(); @@ -809,17 +802,17 @@ StatusOr> Service::BuildExecutable( if (!directory_path.empty() || !execution_directory_path.empty()) { *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s", module_proto.id(), - module_proto.entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_proto.id(), + module_proto.entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } } TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(module_proto, *module_config)); + CreateModuleFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -843,7 +836,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("programe shape may not be empty"); } @@ -858,10 +851,11 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, std::vector> replicated_arguments, ResolveAndValidateArguments(arg->arguments(), replicas)); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(arg->computation().program_shape(), - replicated_arguments.front(), - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(arg->computation().host_program_shape(), + replicated_arguments.front(), + arg->execution_options())); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, @@ -935,16 +929,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, shaped_buffer->device_ordinal())); 
TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, + Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, - result_literal->shape())) { - *result->mutable_literal() = result_literal->ToProto(); + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal->Relayout(*return_shape)->ToProto(); + result_literal.Relayout(*return_shape).ToProto(); } return Status::OK(); } @@ -966,9 +959,9 @@ std::unique_ptr CloneShapedBufferOnDevice( Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -990,7 +983,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), *literal, shaped_buffer)); + stream.get(), literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1010,8 +1003,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, "%s", StrCat("The replica_id=", arg->replica_id(), " on TransferToInfeedRequest not in range [0, replica_count=", - replica_count, ").") - .c_str()); + replica_count, ").")); } se::StreamExecutor* executor; @@ -1026,10 +1018,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + 
TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, *literal); + return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, + literal); } Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, @@ -1037,8 +1029,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( - "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " - "%lld)", + "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", arg->replica_id(), replica_count); } @@ -1058,8 +1049,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), *literal)); - *result->mutable_literal() = literal->ToProto(); + executor, arg->shape_with_layout(), literal)); + *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1073,15 +1064,15 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->computation().program_shape().parameters_size() != 0) { + if (arg->computation().host_program_shape().parameters_size() != 0) { return InvalidArgument( "constant computation may not depend on any parameters."); } - ProgramShape program_shape = arg->computation().program_shape(); + ProgramShape program_shape = arg->computation().host_program_shape(); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if 
(arg->has_output_layout()) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( @@ -1091,21 +1082,20 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModuleConfig config(program_shape); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(arg->computation(), config)); + CreateModuleFromProto(arg->computation(), config)); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate>( - *module, /*arg_literals=*/{})); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( + *module, /*arg_literals=*/{})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); + result_literal = result_literal.Relayout(arg->output_layout()); } - *result->mutable_literal() = result_literal->ToProto(); + *result->mutable_literal() = result_literal.ToProto(); return Status::OK(); } @@ -1122,14 +1112,14 @@ Status Service::GetComputationGraphStats( if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("Program shape may not be empty."); } - HloModuleConfig config(arg->computation().program_shape()); + HloModuleConfig config(arg->computation().host_program_shape()); config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(arg->computation(), config)); + CreateModuleFromProto(arg->computation(), config)); hlo_graph_dumper::MaybeDumpHloModule(*module, "computation statistics subject"); @@ -1171,7 +1161,7 @@ StatusOr> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status 
Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1179,7 +1169,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47d196fb2aaee897ce1fd3745129af10bf5b2d2d..8cf1a7b9f01fbb3572c6849c8b18e14174ced89f 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -176,7 +176,7 @@ class Service : public ServiceInterface { // class. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. 
StatusOr>> GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -207,14 +207,14 @@ class Service : public ServiceInterface { // the corresponding replica. StatusOr>> ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span arguments, + absl::Span stream_executors) const; // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. @@ -242,21 +242,17 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, - tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile); + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile); // Executes a single computation which has more than one target device. 
// The N devices are expected to all return an empty tuple, but one, which @@ -275,7 +271,9 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 6a22f8bef493b3c270e210e1f9ea57fa79612a1d..25afc23e5b41468ad5dd1abed076e399cf20f350 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,7 +22,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,40 +34,36 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { namespace { +using absl::StrFormat; using absl::StrJoin; -using tensorflow::strings::Printf; // Returns true if no element is present in slice more than once. 
-bool AllUnique(tensorflow::gtl::ArraySlice slice) { +bool AllUnique(absl::Span slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!ShapeUtil::IsArray(shape)) { return InvalidArgument("Expected array argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); + string(op_type), ShapeUtil::HumanString(shape)); } return Status::OK(); } -Status VerifyReducerShape( - const ProgramShape& reducer_shape, - tensorflow::gtl::ArraySlice init_value_shapes, - tensorflow::gtl::ArraySlice input_element_types, - int64 inputs) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + absl::Span init_value_shapes, + absl::Span input_element_types, + int64 inputs) { if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take %lld parameters, but " + "Reduction function must take %d parameters, but " "takes %d parameter(s).", inputs * 2, reducer_shape.parameters_size()); } @@ -75,7 +73,7 @@ Status VerifyReducerShape( if (ShapeUtil::IsArray(accumulator_shape)) { if (inputs != 1) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but " + "Reduction function must produce a tuple with %d elements, but " "produces a scalar", inputs); } @@ -83,8 +81,8 @@ Status VerifyReducerShape( } else if (ShapeUtil::IsTuple(accumulator_shape)) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but has " - "%lld elements", + "Reduction function must produce a tuple with %d elements, but has " + "%d elements", inputs, ShapeUtil::TupleElementCount(accumulator_shape)); } for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { @@ -94,7 +92,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must produce a scalar or tuple of scalars, 
but has " "shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } for (const Shape* element_shape : accumulator_subshapes) { @@ -102,7 +100,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } } @@ -113,19 +111,19 @@ Status VerifyReducerShape( if (!ShapeUtil::Compatible(*accumulator_subshapes[i], reducer_shape.parameters(i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "result shape: %s vs %s", - i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + i, ShapeUtil::HumanString(reducer_shape.parameters(i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } // Check that init_value's shapes are suitable for reducer_shape. if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], *init_value_shapes[i])) { return InvalidArgument( - "Reduction function's accumulator shape at index %lld differs from " + "Reduction function's accumulator shape at index %d differs from " "the init_value shape: %s vs %s", - i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), - ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + i, ShapeUtil::HumanString(*accumulator_subshapes[i]), + ShapeUtil::HumanString(*init_value_shapes[i])); } // Check that the inputs can be passed in as the non-accumulator arguments. 
const Shape input_element_shape = @@ -133,11 +131,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( input_element_shape, reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "input type element type: %s vs %s", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(input_element_shape)); } // Check that the accumulator and inputs to the reducer function match. // If the accumulator is scalar, it must have the same type as the inputs @@ -147,11 +145,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape must " + "Reduction function's %d-th parameter shape must " "match the result shape, but got %s vs %s.", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } } @@ -164,7 +162,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, bool allow_negative_padding) { if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { return InvalidArgument( - "Window has dimension %d but base shape has dimension %lld.", + "Window has dimension %d but base shape has dimension %d.", window.dimensions_size(), ShapeUtil::Rank(base_shape)); } @@ -173,29 +171,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const auto& dim = window.dimensions(i); if (dim.size() <= 0) { return InvalidArgument("Window %s has a non-positive dimension.", - 
window.DebugString().c_str()); + window.DebugString()); } if (dim.stride() <= 0) { return InvalidArgument("Window %s has a non-positive stride.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_low() < 0) { return InvalidArgument("Window %s has a negative low padding.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_high() < 0) { return InvalidArgument("Window %s has a negative high padding.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.base_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive base area dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.window_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive window dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } const int64 dilated_base = window_util::DilatedBound( @@ -238,8 +236,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating for %s operation; " "got %s.", - HloOpcodeString(opcode).c_str(), - PrimitiveType_Name(shape.element_type()).c_str()); + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kCos: @@ -254,8 +251,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating or complex for %s " "operation; got %s.", - HloOpcodeString(opcode).c_str(), - PrimitiveType_Name(shape.element_type()).c_str()); + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kReal: @@ -268,8 +264,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating or complex for " "%s operation; got %s.", - HloOpcodeString(opcode).c_str(), - 
PrimitiveType_Name(shape.element_type()).c_str()); + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { @@ -281,15 +276,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating or complex for " "%s operation; got %s.", - HloOpcodeString(opcode).c_str(), - PrimitiveType_Name(shape.element_type()).c_str()); + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } case HloOpcode::kClz: if (!ShapeUtil::ElementIsIntegral(shape)) { return InvalidArgument( "Expected an integral element type in argument to Clz " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kNegate: @@ -299,8 +293,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be integral, floating or " "complex for %s operation; got %s.", - HloOpcodeString(opcode).c_str(), - PrimitiveType_Name(shape.element_type()).c_str()); + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kSign: @@ -309,8 +302,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be signed or complex for " "%s operation; got %s.", - HloOpcodeString(opcode).c_str(), - PrimitiveType_Name(shape.element_type()).c_str()); + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; @@ -320,7 +312,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return shape; @@ -330,25 +322,24 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, "Expected 
element type in shape to be floating " "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, - const int64 dimension) { + absl::Span arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("Concatenate dimension out of bounds: %lld.", + return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } const Shape* arg_shape = nullptr; @@ -362,17 +353,16 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), - ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), - ShapeUtil::HumanString(*shape).c_str()); + ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), + ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "Cannot concatenate arrays with different element types: %s vs %s.", - PrimitiveType_Name(arg_shape->element_type()).c_str(), - PrimitiveType_Name(shape->element_type()).c_str()); + PrimitiveType_Name(arg_shape->element_type()), + PrimitiveType_Name(shape->element_type())); } for (int64 dimension_number = 0; dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { @@ -385,9 
+375,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld.", - ShapeUtil::HumanString(*arg_shape).c_str(), - ShapeUtil::HumanString(*shape).c_str(), dimension); + "the same): %s vs %s in dimension %d.", + ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape), + dimension); } } element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); @@ -402,7 +392,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } /* static */ StatusOr ShapeInference::InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes) { + absl::Span arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { return InvalidArgument( @@ -419,8 +409,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( "Conversion from complex to real type %s => %s is not implemented.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -429,8 +419,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. 
return InvalidArgument( "Convert does not allow non-arrays, so cannot convert from %s to %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -442,8 +432,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { return InvalidArgument("Conversion from complex to real type %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -452,15 +442,15 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. 
return InvalidArgument( "Cannot convert from or to tuple type; requested conversion: %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( "Cannot bitcast types with different bit-widths: %s => %s.", - PrimitiveType_Name(old_element_type).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + PrimitiveType_Name(old_element_type), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -473,7 +463,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating point for " "ReducePrecision operation; got %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. 
Having @@ -505,21 +495,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - padding_config.ShortDebugString().c_str()); + ShapeUtil::HumanString(operand_shape), + padding_config.ShortDebugString()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( "The element types of the operands to Pad do not match."); } + if (absl::c_any_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& p) { + return p.interior_padding() < 0; + })) { + return InvalidArgument("Interior padding cannot be negative: %s", + padding_config.ShortDebugString()); + } + std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { - dimensions[i] = operand_shape.dimensions(i) + - padding_config.dimensions(i).edge_padding_low() + - padding_config.dimensions(i).edge_padding_high() + + const auto& p = padding_config.dimensions(i); + dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * - padding_config.dimensions(i).interior_padding(); + p.interior_padding(); } return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), @@ -550,22 +548,22 @@ Status ValidateDotDimensionNumbers( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { // Check that dimension numbers are in range. 
- auto dims_in_range = - [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_in_range = [](const int64 rank, + absl::Span contracting_dims, + absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; return std::all_of(contracting_dims.begin(), contracting_dims.end(), in_range) && std::all_of(batch_dims.begin(), batch_dims.end(), in_range); }; - tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + absl::Span lhs_contracting_dimensions = AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + absl::Span rhs_contracting_dimensions = AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice lhs_batch_dimensions = + absl::Span lhs_batch_dimensions = AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); - tensorflow::gtl::ArraySlice rhs_batch_dimensions = + absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, @@ -573,13 +571,13 @@ Status ValidateDotDimensionNumbers( !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that dimension numbers are unique. 
- auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { - tensorflow::gtl::FlatSet dim_set; + auto dims_unique = [](absl::Span contracting_dims, + absl::Span batch_dims) -> bool { + absl::flat_hash_set dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; @@ -591,7 +589,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is not unique in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. @@ -636,14 +634,13 @@ Status ValidateDotDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { - string message = tensorflow::strings::Printf( - "Cannot infer shape for dot operation: %s %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + string message = + StrFormat("Cannot infer shape for dot operation: %s %s.", + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); if (!addendum.empty()) { message += " " + addendum; } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); }; // Check if both element types are the same. 
@@ -739,9 +736,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -750,20 +746,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. 
return InvalidArgument("Automatic shape inference not supported: %s and %s", - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " - " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu.", + " lower-rank operand's rank is %d, size of broadcast_dimensions is " + "%u.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -813,12 +809,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "Broadcast dimension number (%lld) cannot be negative.", + "Broadcast dimension number (%d) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "Broadcast dimension number (%lld) too large; higher-rank " + "Broadcast dimension number (%d) too large; higher-rank " "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } @@ -830,16 +826,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, + "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, small_dimension_size, large_dimension_size, - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } // Make sure the broadcast dimensions are listed in a strictly increasing // order. 
if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "Broadcast dimensions order is wrong: %lld comes after %lld.", + "Broadcast dimensions order is wrong: %d comes after %d.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -851,15 +847,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -908,12 +904,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - VLOG(2) << tensorflow::strings::Printf( + absl::Span broadcast_dimensions) { + VLOG(2) << StrFormat( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), - StrJoin(broadcast_dimensions, ", ").c_str()); + HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", ")); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ 
-924,6 +919,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, + broadcast_dimensions); + case HloOpcode::kSubtract: case HloOpcode::kAdd: case HloOpcode::kAtan2: @@ -934,6 +932,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + if (lhs.element_type() == PRED || rhs.element_type() == PRED) { + return InvalidArgument( + "Expected element type in shape to be arithmetic type for " + "operation %s; got PRED.", + HloOpcodeString(opcode)); + } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -942,7 +946,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected element type in shape to be floating for complex compose " "operation; got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, @@ -961,7 +965,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected pred or integral type in argument to and/or operation; " "got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -979,8 +983,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode), lhs.ShortDebugString(), + rhs.ShortDebugString()); } } @@ -1003,14 +1007,12 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kTupleSelect: return InferTupleSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + HloOpcode opcode, absl::Span operands) { std::vector operand_shapes; operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { @@ -1020,8 +1022,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes) { + HloOpcode opcode, absl::Span operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } @@ -1037,30 +1038,33 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kSort: { if (operand_shapes.size() == 1) { return *operand_shapes[0]; - } else if (operand_shapes.size() == 2) { - if (!ShapeUtil::SameDimensions(*operand_shapes[0], - *operand_shapes[1])) { - return InvalidArgument( - "Sort keys and values dimensions must match. " - "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), - ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + } else { + for (int64 operand = 1; operand < operand_shapes.size(); ++operand) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[operand])) { + return InvalidArgument( + "Sort keys and values dimensions must match. 
" + "Keys shape is: %s\n, Values shape (operand index %lld) is: %s", + ShapeUtil::HumanString(*operand_shapes[0]), operand, + ShapeUtil::HumanString(*operand_shapes[operand])); + } + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); } - return ShapeUtil::MakeTupleShape( - {*operand_shapes[0], *operand_shapes[1]}); + return ShapeUtil::MakeTupleShape(operand_shape_values); } return InvalidArgument("Unexpected number of operands for sort"); } default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions) { if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument."); } @@ -1091,7 +1095,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s.", - StrJoin(pieces, ", ").c_str()); + StrJoin(pieces, ", ")); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1099,7 +1103,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", + "arg_dimension_size: %d, requested_map_dimensions_size: %u.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1108,7 +1112,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically 
increasing dimension numbers; got: %s.", - StrJoin(dimensions, ", ").c_str()); + StrJoin(dimensions, ", ")); } } @@ -1116,7 +1120,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu.", + "arity: %d, arguments: %u.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1125,7 +1129,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( "Mapped computation's result has to be a scalar; got: %s.", - ShapeUtil::HumanString(output_shape).c_str()); + ShapeUtil::HumanString(output_shape)); } for (int i = 0; i < to_apply.parameters_size(); ++i) { @@ -1135,7 +1139,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter has to be a scalar; " "got parameter %d shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, @@ -1143,8 +1147,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str(), - ShapeUtil::HumanString(*arg_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape), + ShapeUtil::HumanString(*arg_shape)); } } @@ -1173,35 +1177,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, 
ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld.", + "batch-norm-training to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1209,7 +1213,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-training must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1218,8 +1222,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1228,8 +1232,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for 
batch-norm-training, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1239,16 +1243,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1283,35 +1287,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld.", + "batch-norm-inference to be at least 1; got %d.", 
ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1319,7 +1323,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-inference must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1329,8 +1333,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1340,8 +1344,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1351,8 +1355,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of mean is %s " "and the 
shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, @@ -1362,8 +1366,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of variance is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(variance_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(variance_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1373,32 +1377,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of mean is %lld " - "and the feature count is %lld.", + "but the size of mean is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if 
(ShapeUtil::GetDimension(variance_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1428,36 +1432,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" - " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld.", + " output_grad_shape; got rank(oprand_shape) %d, and" + " rank(output_grad_shape) %d.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } if (ShapeUtil::Rank(mean_shape) != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(mean_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } if (ShapeUtil::Rank(var_shape) != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(var_shape)); } @@ -1465,14 +1469,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - 
PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, @@ -1481,8 +1485,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " "and the element type of operand is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1491,8 +1495,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " "and the element type of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1501,8 +1505,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + 
PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, @@ -1511,8 +1515,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1523,24 +1527,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1550,8 +1554,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, 
ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," - "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld.", + "but the bound of operand_shape at dimension %d is %d " + "and the bound of output_grad_shape is %d.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1562,23 +1566,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); } if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" "Window: %s", - window.DebugString().c_str()); + window.DebugString()); } const int num_spatial_dims = dnums.input_spatial_dimensions_size(); @@ -1586,19 +1589,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" "Window: %s\nDimension numbers: %s.", - window.DebugString().c_str(), dnums.DebugString().c_str()); + window.DebugString(), dnums.DebugString()); } const int num_dims = num_spatial_dims + 2; if 
(ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1635,26 +1638,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } std::vector input_spatial_dims(num_spatial_dims); @@ -1675,13 +1678,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( - "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension * feature_group_count (value %lld); " + "Expected LHS feature dimension (value %d) to match RHS " + "input feature dimension * feature_group_count (value %d * %d = %d); " 
"got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features * feature_group_count, - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); + input_features, kernel_input_features, feature_group_count, + kernel_input_features * feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } + if (kernel_output_features % feature_group_count > 0) { + return InvalidArgument( + "Expected output feature dimension (value %d) to be divisible by " + "feature_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); } std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1693,8 +1707,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "RHS shape: %s\n\t" "Window: {%s}\n\t" "Dimension numbers: {%s}.", - ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), - dnums.ShortDebugString().c_str()); + ShapeUtil::HumanString(rhs), window.ShortDebugString(), + dnums.ShortDebugString()); } Shape base_shape = @@ -1717,32 +1731,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferFftShape( const Shape& in, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank); } -#define RET_CHECK_RANK(x) \ - if (x.dimensions_size() < fft_rank) { \ - return InvalidArgument( \ - "FFT of rank %lld requires input of at least " \ - "same rank; got input of rank %d", \ - fft_rank, x.dimensions_size()); \ 
+#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %d requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ } switch (fft_type) { case FFT: case IFFT: if (in.element_type() != C64) { return InvalidArgument("%s requires C64 input type, found %s.", - FftType_Name(fft_type).c_str(), - PrimitiveType_Name(in.element_type()).c_str()); + FftType_Name(fft_type), + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); return in; case RFFT: { if (in.element_type() != F32) { return InvalidArgument("RFFT requires F32 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); for (int i = 0; i < fft_rank; i++) { @@ -1750,7 +1764,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1764,7 +1778,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case IRFFT: { if (in.element_type() != C64) { return InvalidArgument("IRFFT requires C64 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); Shape result = ShapeUtil::ComplexComponentShape(in); @@ -1773,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld.", + "fft_length, but dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), 
fft_length[i]); @@ -1783,7 +1797,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1798,7 +1812,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( ExpectArray(*operand_shape, "operand of cross replica sum")); @@ -1819,18 +1833,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(split_count > 0); if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { return InvalidArgument( - "AllToAll split_dimension %lld is out-of-bounds in shape %s.", - split_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll split_dimension %d is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape)); } if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { return InvalidArgument( - "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", - concat_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll concat_dimension %d is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape)); } if (shape.dimensions(split_dimension) % split_count != 0) { return InvalidArgument( - "AllToAll split dimension size %lld must be dividable by split_count " - "%lld.", + "AllToAll split dimension size %d must be dividable by split_count " + "%d.", shape.dimensions(split_dimension), split_count); } std::vector new_dimensions(shape.dimensions().begin(), @@ -1841,7 +1855,7 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { // An Alltoall HLO instruction receives N operands (with the same shape) and // returns a tuple that contains N array shapes. TF_RET_CHECK(!operand_shapes.empty()); @@ -1850,17 +1864,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "HLO all-to-all has operands with different shapes: the 0th " "operand shape %s, but the %dth operand has shape %s.", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, - ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), i, + ShapeUtil::HumanString(*operand_shapes[i])); } } return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ StatusOr ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr ShapeInference::InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply) { if (arg_shapes.empty()) { return InvalidArgument("Reduce must have at least 2 arguments, has 0"); @@ -1872,17 +1892,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 num_reduced_args = arg_shapes.size() / 2; - tensorflow::gtl::ArraySlice reduced_args(arg_shapes, 0, - num_reduced_args); + auto reduced_args = arg_shapes.subspan(0, num_reduced_args); // Check that all of the reduced tensors have the same dimensions. The element // types may be different. for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( "All reduced tensors must have the sime dimension. 
Tensor 0 has " - "shape %s, Tensor %lld has shape %s", - ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, - ShapeUtil::HumanString(*reduced_args[i]).c_str()); + "shape %s, Tensor %d has shape %s", + ShapeUtil::HumanString(*reduced_args[0]), i, + ShapeUtil::HumanString(*reduced_args[i])); } } @@ -1892,14 +1911,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { - return InvalidArgument( - "Reducing out-of-bounds dimension %lld in shape %s.", dimension, - ShapeUtil::HumanString(arg).c_str()); + return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", + dimension, ShapeUtil::HumanString(arg)); } } - tensorflow::gtl::ArraySlice init_values( - arg_shapes, num_reduced_args, arg_shapes.size()); + auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size()); std::vector element_types; for (const Shape* arg : reduced_args) { element_types.push_back(arg->element_type()); @@ -1967,16 +1984,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select function's first parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(0)), + ShapeUtil::HumanString(operand_element_shape)); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( "Select function's second parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(1)), + 
ShapeUtil::HumanString(operand_element_shape)); } // Check if the scatter function has a proper shape as a reduction. @@ -1994,43 +2011,40 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s).", - ShapeUtil::HumanString(source_shape).c_str(), - ShapeUtil::HumanString(window_result_shape).c_str()); + ShapeUtil::HumanString(source_shape), + ShapeUtil::HumanString(window_result_shape)); } return operand_shape; } /* static */ StatusOr ShapeInference::InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides) { + const Shape& arg, absl::Span starts, + absl::Span limits, absl::Span strides) { auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " "{%s}; strides: {%s}.", - message.c_str(), ShapeUtil::HumanString(arg).c_str(), - StrJoin(starts, ",").c_str(), StrJoin(limits, ",").c_str(), - StrJoin(strides, ",").c_str()); + message, ShapeUtil::HumanString(arg), StrJoin(starts, ","), + StrJoin(limits, ","), StrJoin(strides, ",")); }; TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); - VLOG(2) << tensorflow::strings::Printf( - "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), StrJoin(starts, ", ").c_str(), - StrJoin(limits, ", ").c_str()); + VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}", + ShapeUtil::HumanString(arg), StrJoin(starts, ", "), + StrJoin(limits, ", ")); if (starts.size() != limits.size()) { - return error(Printf("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size())); + return error(StrFormat("slice start and limit sizes differ: %u vs %u", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return error(Printf("slice start and 
strides sizes differ: %zu vs %zu", - starts.size(), strides.size())); + return error(StrFormat("slice start and strides sizes differ: %u vs %u", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "Slice index count does not match argument rank: %zu vs %lld.", + "Slice index count does not match argument rank: %u vs %d.", starts.size(), ShapeUtil::Rank(arg)); } @@ -2040,27 +2054,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("Negative start index to slice: %lld.", - start_index); + return InvalidArgument("Negative start index to slice: %d.", start_index); } if (limit_index > arg.dimensions(dimension)) { return error( - Printf("limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension))); - } - VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, - start_index); - VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, - limit_index); + StrFormat("limit index (%d) must be less than or equal to dimension " + "size (%d)", + limit_index, arg.dimensions(dimension))); + } + VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); + VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); if (start_index > limit_index) { return error( - Printf("limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index)); + StrFormat("limit index (%d) must be greater or equal to " + "start index (%d) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("Stride (%lld) must be positive.", stride); + return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2070,20 
+2081,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - StrJoin(slice_sizes, ", ").c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2095,16 +2105,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "Dynamic slice index count does not match argument rank: %zu vs %lld.", + "Dynamic slice 
index count does not match argument rank: %u vs %d.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2112,16 +2121,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("Negative size index to dynamic slice: %lld.", + return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "Slice dim size %lld greater than dynamic slice dimension: %lld.", + "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, - slice_dim_size); + VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); @@ -2137,16 +2145,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, "start indices of dynamic update slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "updating slice of shape %s at dynamic start_indices %s with update " "shape %s", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::HumanString(update_shape).c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2158,17 +2166,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = 
start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic update slice start number of dimensions %lld (%s) must match " - "rank %lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " - "%lld vs %lld.", + "%d vs %d.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } @@ -2177,8 +2184,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Dynamic update slice update element type does not match argument. 
" "operand.element_type: %s vs update.element_type: %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str(), - PrimitiveType_Name(update_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type()), + PrimitiveType_Name(update_shape.element_type())); } for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { @@ -2186,23 +2193,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "Size index %lld to dynamic update slice must be >= 0.", + "Size index %d to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "Update dim size %lld greater than dynamic slice dimension: %lld.", + "Update dim size %d greater than dynamic slice dimension: %d.", update_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, - update_dim_size); + VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size); } return operand_shape; } /*static */ StatusOr ShapeInference::InferReverseShape( - const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand_shape, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); @@ -2210,8 +2216,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", - dimension, ShapeUtil::HumanString(operand_shape).c_str()); + "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", + dimension, ShapeUtil::HumanString(operand_shape)); } } return operand_shape; @@ 
-2222,14 +2228,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", - ShapeUtil::HumanString(arg).c_str()); + ShapeUtil::HumanString(arg)); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "Cannot infer shape: attempt to index out of tuple bounds: %lld " + "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", - index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); + index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg)); } return arg.tuple_shapes(index); @@ -2249,17 +2255,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } auto shape_string = [&]() { - return tensorflow::strings::Printf( - "Condition: %s; body: %s; init: %s.", - ShapeUtil::HumanString(condition).c_str(), - ShapeUtil::HumanString(body).c_str(), - ShapeUtil::HumanString(init).c_str()); + return StrFormat( + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition), + ShapeUtil::HumanString(body), ShapeUtil::HumanString(init)); }; // Check the shapes of computation parameters and return types. 
if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { return InvalidArgument("Condition must return a boolean; got %s.", - shape_string().c_str()); + shape_string()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || @@ -2267,7 +2271,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The parameter of condition and body, the result of the body, and init " "must all have the same shape; got %s.", - shape_string().c_str()); + shape_string()); } return init; @@ -2279,7 +2283,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { return InvalidArgument("Predicate must be a boolean; got %s.", - ShapeUtil::HumanString(predicate).c_str()); + ShapeUtil::HumanString(predicate)); } if (true_computation.parameters_size() != 1) { @@ -2288,15 +2292,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { auto true_shape_string = [&]() { - return tensorflow::strings::Printf( - "true_operand: %s; true_computation: %s", - ShapeUtil::HumanString(true_operand).c_str(), - ShapeUtil::HumanString(true_computation).c_str()); + return StrFormat("true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand), + ShapeUtil::HumanString(true_computation)); }; return InvalidArgument( "true_operand must match the shape of the only parameter of " "true_computation: got %s.", - true_shape_string().c_str()); + true_shape_string()); } if (false_computation.parameters_size() != 1) { @@ -2305,38 +2308,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { auto false_shape_string = [&]() { - return tensorflow::strings::Printf( 
- "false_operand: %s; false_computation: %s", - ShapeUtil::HumanString(false_operand).c_str(), - ShapeUtil::HumanString(false_computation).c_str()); + return StrFormat("false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand), + ShapeUtil::HumanString(false_computation)); }; return InvalidArgument( "false_operand must match the shape of the only parameter of " "false_computation: got %s.", - false_shape_string().c_str()); + false_shape_string()); } if (!ShapeUtil::Compatible(true_computation.result(), false_computation.result())) { auto shape_string = [&]() { - return tensorflow::strings::Printf( + return StrFormat( "true_computation result: %s; false_computation result: %s.", - ShapeUtil::HumanString(true_computation.result()).c_str(), - ShapeUtil::HumanString(false_computation.result()).c_str()); + ShapeUtil::HumanString(true_computation.result()), + ShapeUtil::HumanString(false_computation.result())); }; return InvalidArgument( "the result of true_computation and false_computation must have the " "same shape: got %s.", - shape_string().c_str()); + shape_string()); } return true_computation.result(); } /* static */ StatusOr ShapeInference::InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + const Shape& operand, absl::Span broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { - return InvalidArgument("Broadcast with negative dimension size %lld.", + return InvalidArgument("Broadcast with negative dimension size %d.", size); } } @@ -2350,8 +2352,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + const Shape& operand, absl::Span dimensions, + absl::Span new_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); 
Shape inferred_shape = @@ -2361,11 +2363,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "Reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s).", - ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), ShapeUtil::ElementsIn(inferred_shape), - ShapeUtil::HumanString(inferred_shape).c_str()); + ShapeUtil::HumanString(inferred_shape)); } std::vector indices(ShapeUtil::Rank(operand)); @@ -2376,15 +2378,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - StrJoin(dimensions, ",").c_str(), - ShapeUtil::HumanString(operand).c_str()); + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } return inferred_shape; } /* static */ StatusOr ShapeInference::InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); @@ -2393,7 +2394,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Transpose dimensions not a permutation of the operand dimensions."); + "Transpose dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. 
However, @@ -2412,9 +2415,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", - ShapeUtil::HumanString(min).c_str(), - ShapeUtil::HumanString(operand).c_str(), - ShapeUtil::HumanString(max).c_str()); + ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); } if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || ShapeUtil::IsScalar(min)) && @@ -2431,9 +2434,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::ChangeElementType(min, operand.element_type()); } } - return Unimplemented( - "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), - max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); + return Unimplemented("%s, %s %s is not implemented.", + min.ShortDebugString(), max.ShortDebugString(), + operand.ShortDebugString()); } // TODO(b/36794510): Make broadcast semantics more consistent, by supporting @@ -2444,13 +2447,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "Select's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || ShapeUtil::IsScalar(pred)) { @@ -2463,7 +2465,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( 
"Select operation with non-scalar predicate with dimensionality " " different from the other operands: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } } @@ -2474,25 +2476,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::Compatible(on_true, on_false)) { return InvalidArgument( "Operands to tuple-select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "TupleSelect's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (!ShapeUtil::IsScalar(pred)) { return InvalidArgument( "TupleSelect operation with non-scalar predicate: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } return on_true; } /* static */ StatusOr ShapeInference::InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply) { + absl::Span arg_shapes, const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); @@ -2502,10 +2502,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu; computation signature: %s; argument " + "arity: %d, arguments: %u; computation signature: %s; argument " "shapes: [%s].", - to_apply.parameters_size(), arg_shapes.size(), - computation_signature.c_str(), argument_shapes.c_str()); + to_apply.parameters_size(), arg_shapes.size(), computation_signature, + argument_shapes); } // All arguments must be compatible with the program shape. 
@@ -2516,8 +2516,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " "argument shape: %s.", - i, ShapeUtil::HumanString(param_shape).c_str(), - ShapeUtil::HumanString(arg_shape).c_str()); + i, ShapeUtil::HumanString(param_shape), + ShapeUtil::HumanString(arg_shape)); } } @@ -2525,20 +2525,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } static Status ValidateGatherDimensionNumbers( - const Shape& input_shape, - tensorflow::gtl::ArraySlice start_indices_shape, + const Shape& input_shape, absl::Span start_indices_shape, const GatherDimensionNumbers& dim_numbers) { if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - StrJoin(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.offset_dims()) != dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - StrJoin(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); @@ -2549,9 +2548,9 @@ static Status ValidateGatherDimensionNumbers( int64 offset_dim = dim_numbers.offset_dims(i); if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Offset dimension %d in gather op is out of bounds; got %lld, but " + "Offset dimension %d in gather op is out of bounds; got %d, but " "should " - "have been in [0,%lld).", + "have been in [0,%d).", i, offset_dim, output_shape_rank); } } @@ -2560,8 +2559,8 @@ static Status ValidateGatherDimensionNumbers( start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Gather op has %d elements in start_index_map and the " - "bound of dimension 
index_vector_dim=%lld of start_indices is " - "%lld. These two numbers must be equal.", + "bound of dimension index_vector_dim=%d of start_indices is " + "%d. These two numbers must be equal.", dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), start_indices_shape[dim_numbers.index_vector_dim()]); } @@ -2571,7 +2570,7 @@ static Status ValidateGatherDimensionNumbers( if (operand_dim_for_start_index_i < 0 || operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid start_index_map; domain is [0, %d), got: %d->%lld.", + "Invalid start_index_map; domain is [0, %d), got: %d->%d.", input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } @@ -2587,14 +2586,14 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - StrJoin(dim_numbers.start_index_map(), ", ").c_str()); + StrJoin(dim_numbers.start_index_map(), ", ")); } for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid collapsed_slice_dims set in gather op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", input_shape.dimensions_size(), collapsed_dim); } } @@ -2602,7 +2601,7 @@ static Status ValidateGatherDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( "collapsed_slice_dims in gather op must be sorted; got: %s", - StrJoin(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != @@ -2610,7 +2609,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - StrJoin(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", 
")); } return Status::OK(); @@ -2619,7 +2618,7 @@ static Status ValidateGatherDimensionNumbers( /*static*/ StatusOr ShapeInference::InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( @@ -2628,7 +2627,7 @@ static Status ValidateGatherDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(start_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape)); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if @@ -2641,7 +2640,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Gather index leaf dimension must be within [0, rank(start_indices) + " "1). rank(start_indices) is %d and gather index leaf dimension is " - "%lld.", + "%d.", start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } @@ -2672,9 +2671,8 @@ static Status ValidateGatherDimensionNumbers( "All components of the offset index in a gather op must either be a " "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " "output_slice_sizes=%s, collapsed_slice_dims=%s.", - slice_sizes.size(), - StrJoin(gather_dim_numbers.offset_dims(), ",").c_str(), - StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); + slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","), + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",")); } for (int i = 0; i < slice_sizes.size(); i++) { @@ -2683,7 +2681,7 @@ static Status ValidateGatherDimensionNumbers( if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( "Slice size at index %d in gather op is out of range, must be " - "within [0, %lld), got 
%lld.", + "within [0, %d), got %d.", i, corresponding_input_size + 1, slice_size); } } @@ -2692,7 +2690,7 @@ static Status ValidateGatherDimensionNumbers( if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( "Gather op can only collapse slice dims with bound 1, but bound is " - "%lld for index %lld at position %d.", + "%d for index %d at position %d.", slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], gather_dim_numbers.collapsed_slice_dims(i), i); } @@ -2730,27 +2728,26 @@ static Status ValidateGatherDimensionNumbers( namespace { Status ValidateScatterDimensionNumbers( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& operand_shape, absl::Span scatter_indices_shape, const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { // Validate update_window_dims in ScatterDimensionNumbers. if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", - StrJoin(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", - StrJoin(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } const int64 updates_rank = ShapeUtil::Rank(updates_shape); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( "Invalid update_window_dims set in scatter op; valid range is [0, " - "%lld). got: %lld.", + "%d). 
got: %d.", updates_rank, window_dim); } } @@ -2759,19 +2756,19 @@ Status ValidateScatterDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", - StrJoin(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", - StrJoin(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid inserted_window_dims set in scatter op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", operand_shape.dimensions_size(), inserted_dim); } } @@ -2781,7 +2778,7 @@ Status ValidateScatterDimensionNumbers( scatter_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Scatter op has %d elements in scatter_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "bound of dimension index_vector_dim=%d of scatter_indices is %d. 
" "These two numbers must be equal.", dim_numbers.scatter_dims_to_operand_dims_size(), dim_numbers.index_vector_dim(), @@ -2794,7 +2791,7 @@ Status ValidateScatterDimensionNumbers( scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", + "got: %d->%d.", operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); } } @@ -2807,7 +2804,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " "got: %s.", - StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ")); } return Status::OK(); @@ -2828,7 +2825,7 @@ Status ValidateScatterDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { return InvalidArgument( "Scatter indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(scatter_indices_shape).c_str()); + ShapeUtil::HumanString(scatter_indices_shape)); } if (scatter_indices_shape.dimensions_size() < @@ -2837,7 +2834,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Scatter index leaf dimension must be within [0, rank(scatter_indices)" " + 1). 
rank(scatter_indices) is %d and scatter index leaf dimension " - "is %lld.", + "is %d.", scatter_indices_shape.dimensions_size(), scatter_dim_numbers.index_vector_dim()); } @@ -2859,7 +2856,7 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { - return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + return InvalidArgument("Updates tensor must be of rank %d; got %d.", expected_updates_rank, ShapeUtil::Rank(updates_shape)); } @@ -2885,7 +2882,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. For dimension " - "%lld, updates bound is %lld, operand bound is %lld.", + "%d, updates bound is %d, operand bound is %d.", update_window_dim, updates_shape.dimensions(update_window_dim), max_update_slice_sizes[i]); } @@ -2906,8 +2903,8 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the scatter dimensions of updates must be same as the " "bounds of the corresponding dimensions of scatter indices. For " - "scatter dimension %lld, updates bound is %lld, scatter_indices " - "bound is %lld.", + "scatter dimension %d, updates bound is %d, scatter_indices " + "bound is %d.", i, updates_shape.dimensions(i), expanded_scatter_indices_shape[scatter_dims_seen]); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 4974ac9916abaea25f8d455b24f7c0904277f5f7..96a0ee165d46753da4fef119e7072f66637bf2c4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -21,12 +21,12 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -55,7 +55,7 @@ class ShapeInference { // given input shapes. static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -73,18 +73,15 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes); + HloOpcode opcode, absl::Span operand_shapes); static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + HloOpcode opcode, absl::Span operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. static StatusOr InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions); + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. @@ -111,19 +108,18 @@ class ShapeInference { // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. 
static StatusOr InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. - static StatusOr InferFftShape( - const Shape& in, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + static StatusOr InferFftShape(const Shape& in, FftType fft_type, + absl::Span fft_length); // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla // builder. @@ -134,7 +130,10 @@ class ShapeInference { // Infers the shape of an HLO all-to-all instruction. static StatusOr InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); + + // Infers the shape of a collective permute operation. + static StatusOr InferCollectivePermuteShape(const Shape& shape); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -143,8 +142,8 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand @@ -162,24 +161,23 @@ class ShapeInference { // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. 
- static StatusOr InferReverseShape( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimensions); + static StatusOr InferReverseShape(const Shape& operand_shape, + absl::Span dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides); + static StatusOr InferSliceShape(const Shape& arg, + absl::Span starts, + absl::Span limits, + absl::Span strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -210,30 +208,30 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + const Shape& operand, absl::Span broadcast_sizes); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + static StatusOr InferReshapeShape(const Shape& operand, + absl::Span dimensions, + absl::Span new_sizes); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. 
static StatusOr InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions); + const Shape& operand, absl::Span dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. static StatusOr InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + absl::Span arg_shapes, int64 dimension); // Infers the shape produced by a kAfterAll. Trivially this shape is always a // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes // and checking operand shapes. This method verifies that the operand shapes // are all TOKENs. static StatusOr InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes); + absl::Span arg_shapes); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that @@ -263,8 +261,7 @@ class ShapeInference { // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. static StatusOr InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply); + absl::Span arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. @@ -278,7 +275,7 @@ class ShapeInference { static StatusOr InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, @@ -296,7 +293,7 @@ class ShapeInference { // even in the presence of broadcasting of one of the operands over the other. 
static StatusOr InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Helper for inferring the shape of Clamp ops. static StatusOr InferClampShape(const Shape& min, const Shape& operand, @@ -324,7 +321,7 @@ class ShapeInference { // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); }; diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 4ed8fc6b8654fb87701a629c1ded397fe23e52cd..7b65e8c1c9d2bc730c6c8550e9265b69fdde71cf 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,18 +17,17 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { // Helper that runs reduce shape inference with the input 'arg' and given // dimensions to reduce, and checks the inferred shape is as expected. The // element type here is hard-coded to F32. 
- void ExpectInferredReduceShape( - const Shape& expected_inferred_shape, const Shape& arg, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + void ExpectInferredReduceShape(const Shape& expected_inferred_shape, + const Shape& arg, + absl::Span dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( {&arg, &f32_}, dimensions_to_reduce, to_apply); @@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const tensorflow::gtl::ArraySlice& bcast) { + const absl::Span& bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -420,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -465,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -510,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(1); 
dim1->set_base_dilation(2); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -548,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); @@ -1619,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) { auto values = ShapeUtil::MakeShape(F32, {5}); StatusOr statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); - ASSERT_FALSE(statusor.ok()); + EXPECT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} +TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_good = ShapeUtil::MakeShape(F32, {4}); + auto values_bad = ShapeUtil::MakeShape(F32, {5}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_good, &values_bad}); + EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } +TEST_F(ShapeInferenceTest, SortManyValues) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_s32 = ShapeUtil::MakeShape(S32, {4}); + auto values_u32 = ShapeUtil::MakeShape(U32, {4}); + StatusOr statusor = 
ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_s32, &values_u32}); + EXPECT_IS_OK(statusor); + Shape inferred_shape = statusor.ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Compatible( + inferred_shape, + ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); +} + class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 5c12dc37b73f92ade419604bfedac55e35fa9f3f..56952e3adae59656605a12fd499162504a2a3379 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,21 +18,19 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::Appendf; - ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const se::Platform* platform, int device_ordinal) @@ -93,9 +91,9 @@ string ShapedBuffer::ToString() const { shape_str = ShapeUtil::HumanStringWithLayout(subshape); } const se::DeviceMemoryBase& memory = buffer(index); - Appendf(&s, " %s%p (%lld bytes) : %s\n", - string(index.size() * 2, ' ').c_str(), memory.opaque(), - memory.size(), shape_str.c_str()); + absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n", + string(index.size() * 2, ' '), memory.opaque(), + memory.size(), shape_str); }); return s; } @@ -149,7 
+147,7 @@ void ScopedShapedBuffer::Deallocate() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - tensorflow::gtl::FlatSet deallocated_ptrs; + absl::flat_hash_set deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index 905a7e82e621f2bf4588b71be5dbab20f892cafe..e1d26da4a20c0105be304b1a34c81515fcdc6b7f 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -84,6 +84,14 @@ class ShapedBuffer { *buffers_.mutable_element(index) = buffer; } + // Sets all buffers. + // + // Precondition: buffers.shape == on_device_shape_ + void set_buffers(ShapeTree buffers) { + CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_)); + buffers_ = std::move(buffers); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. 
const ShapeTree& buffers() const { return buffers_; } diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc deleted file mode 100644 index 8cbaac7b3760717bcacb57adc8782a5755c0aa6d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/source_map_util.h" - -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace source_map_util { -namespace { - -Status InvalidParameterArgumentV(const OpMetadata& op_metadata, - const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - if (!op_metadata.source_file().empty()) { - tensorflow::strings::Appendf(&message, " (%s:%d)", - op_metadata.source_file().c_str(), - op_metadata.source_line()); - } - return InvalidArgument("%s", message.c_str()); -} - -} // namespace - -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidParameterArgumentV(op_metadata, format, args); - va_end(args); - return result; -} - -Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) 
{ - va_list args; - va_start(args, format); - if (executable != nullptr && executable->has_module()) { - const HloModule& module = executable->module(); - const HloComputation& computation = *module.entry_computation(); - HloInstruction* param = computation.parameter_instruction(parameter_number); - const OpMetadata& metadata = param->metadata(); - Status result = InvalidParameterArgumentV(metadata, format, args); - va_end(args); - return result; - } - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -} // namespace source_map_util -} // namespace xla diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index 84607cd012a9cff4eee5759b4235b2563692f84f..c5a7e17cb44c2b3b5ef145da0d66b4b3160f9531 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -23,6 +24,19 @@ limitations under the License. namespace xla { namespace source_map_util { +// Creates an INVALID_ARGUMENT status with the given format string. +template +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const absl::FormatSpec& format, + const Args&... args) { + string message = absl::StrFormat(format, args...); + if (!op_metadata.source_file().empty()) { + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message); +} + // Creates an INVALID_ARGUMENT status with the given format string. 
// // Also, attempts to extract the OpMetadata for parameter_number on executable @@ -30,15 +44,19 @@ namespace source_map_util { // // executable may be nullptr, but parameter_number should not be out of bounds // or a CHECK-failure may occur. +template Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(3, 4); - -// As above, but takes the parameter metadata directly instead of extracting it -// from the executable. -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(2, 3); + const absl::FormatSpec& format, + const Args&... args) { + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + return InvalidParameterArgument(metadata, format, args...); + } + return InvalidArgument(format, args...); +} } // namespace source_map_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index 5d1cd1c4422a10e3b9e6ce6fac2c83594bb58b30..ec09dff9244080d24580cad8ee2359a34a6a4f96 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { // Re-use an existing stream from the pool. 
stream = std::move(streams_.back()); streams_.pop_back(); - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool reusing existing stream"; + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } else { + VLOG(1) << stream->DebugStreamPointers() + << " stream was not ok, StreamPool deleting"; + stream = nullptr; + } } } diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc index aaf5c37b0d250f78cb57639255ac9b59e1b462f7..92f47579d31303b39f6f3a1859789588b586db87 100644 --- a/tensorflow/compiler/xla/service/stream_pool_test.cc +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) { EXPECT_EQ(stream2_ptr, stream3_ptr); } +TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Return the stream, but hold a handle to it. + se::Stream* stream1_ptr = stream1.get(); + stream1 = nullptr; + + // Now stream1 is back in the pool, force an error on the stream. Here we call + // a method that requires DNN support, which we know the Host platform doesn't + // support. + stream1_ptr->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1_ptr->ok()); + + // Borrow stream2. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. 
+ // + // The check that stream2->ok() serves as a good-enough check. +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 0c577ec67a2bbc18f99ae118c15753bd4f3687f9..a21e586efadb85d18e88e44999283b28f7f65eac 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr> TransferManager::TransferLiteralFromDevice( +StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { - StatusOr> ret; + StatusOr ret; se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); @@ -63,7 +63,7 @@ StatusOr> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferLiteralFromDevice( @@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice( return substream->BlockHostUntilDone(); } -StatusOr> TransferManager::TransferArrayFromDevice( +StatusOr TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { - StatusOr> ret; + StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. 
@@ -122,7 +122,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferArrayToDevice( @@ -149,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync( if (dest.size() < GetByteSizeRequirement(on_device_shape)) { return FailedPrecondition( "Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, @@ -166,12 +166,12 @@ void TransferManager::TransferArrayFromDevice( auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); - return done(FailedPrecondition("%s", error.c_str())); + return done(FailedPrecondition("%s", error)); } if (source.size() < GetByteSizeRequirement(shape)) { return done( FailedPrecondition("Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, @@ -203,7 +203,7 @@ void TransferManager::TransferArrayFromDevice( return NotFound( "could not find registered transfer manager for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.manager == nullptr) { @@ -254,7 +254,7 @@ Status TransferManager::TransferBufferFromDevice( if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", source.size(), size); } stream->ThenMemcpy(destination, source, size); @@ -267,7 +267,7 @@ Status TransferManager::TransferBufferToDevice( if (destination->size() < size) { return FailedPrecondition( "Destination allocation on device not large enough for data tranfer: " - "%lld < %lld", + 
"%d < %d", destination->size(), size); } stream->ThenMemcpy(destination, source, size); @@ -278,9 +278,8 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { - return InvalidArgument( - "Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape)); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index f77690a46215e7f9e16f89f85f07e93e37417c35..f952e64af2b675b9c0f8a30e9a2bc3c855e34efa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,12 +20,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -57,7 +57,7 @@ class TransferManager { // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. 
- virtual StatusOr> TransferLiteralFromDevice( + virtual StatusOr TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, @@ -113,9 +113,9 @@ class TransferManager { Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - StatusOr> TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source); + StatusOr TransferArrayFromDevice(se::Stream* stream, + const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. @@ -130,7 +130,7 @@ class TransferManager { // Resets the devices associated with this transfer manager. virtual Status ResetDevices( - tensorflow::gtl::ArraySlice executor) = 0; + absl::Span executor) = 0; // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the @@ -211,8 +211,7 @@ class TransferManager { // to construct a tuple index table in the platform-specific tuple // representation. 
virtual Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; private: diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 530f40e4b2f9c7c19fa29dad28a077b9d4d68a71..7c1f4b5cc67dd2a84271b4f2b8015fdb2ff6e846 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { } std::unique_ptr new_dot = HloInstruction::CreateDot( - dot->shape(), new_lhs, new_rhs, new_dim_numbers); - new_dot->set_precision_config(dot->precision_config()); + dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - new_conv->set_precision_config(convolution.precision_config()); + convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), + convolution.window(), new_dnums, convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 3e5aa2db60ee31d9fbccf8f7256b15c1b8465335..f95f982eb89d60884b652cd832dff0363372369c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -23,7 +23,7 @@ namespace xla { // HLO pass that folds transpose operators into Dot operators, where the Dot // operator is implemented by a GEMM kernel that can transpose its inputs. 
-class TransposeFolding : public HloPassInterface { +class TransposeFolding : public HloModulePass { public: using OperandIndices = std::vector; diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 58f767e913fbc0023e0c45a4f0e82ecefeeef2d6..79b5c09abb355cd067a4891af558c8c44d80d88e 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -240,10 +240,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -293,10 +295,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -351,10 +355,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { 
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -415,10 +421,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index cb07b8d4d31ae1e11ea82f60c56c841ca37295cf..96f3055c98e0611dfe25517cb490014a6d1f7c76 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -148,7 +148,7 @@ TuplePointsToAnalysis::Run(const HloModule* module) { Status TuplePointsToAnalysis::Analyze() { per_instruction_.clear(); - per_instruction_.resize(module_->NumUniqueInstructionIds()); + per_instruction_.reserve(module_->instruction_count()); logical_buffer_aliases_.clear(); logical_buffer_aliases_.resize( @@ -280,16 +280,6 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { - // A kSlice instruction aliases its operand if the backend lowers it to an - // in-place implementation. - if (slice->IsInPlaceSlice()) { - CreateCopiedPointsToSet(slice, slice->operand(0)); - return Status::OK(); - } - return DefaultAction(slice); -} - Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. 
@@ -360,7 +350,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { } Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); points_to_set.AddPointedToBuffer( logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}), @@ -455,28 +445,22 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - // kSlice ops that are lowered to an in-place version are expected to not - // define their output buffer. - if (buffer.instruction()->opcode() != HloOpcode::kSlice || - !buffer.instruction()->IsInPlaceSlice()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); - } + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString(), buffer.instruction()->name()); } if (buffer.id() < 0 || buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: invalid id %lld", - buffer.ToString().c_str(), buffer.id()); + return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d", + buffer.ToString(), buffer.id()); } if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || GetBuffer(buffer.id()).index() != buffer.index()) { return FailedPrecondition( "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", - buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str()); + buffer.ToString(), GetBuffer(buffer.id()).ToString()); } return Status::OK(); @@ -495,7 +479,7 @@ StatusOr 
TuplePointsToAnalysis::GetBufferDefinedAt( if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { return FailedPrecondition( "instruction %s does not define buffer at index {%s}", - instruction->name().c_str(), absl::StrJoin(index, ",").c_str()); + instruction->name(), absl::StrJoin(index, ",")); } return buffers[0]; } @@ -556,8 +540,8 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( } string TuplePointsToAnalysis::ToString() const { - string output = tensorflow::strings::Printf( - "TuplePointsToSet for module %s:\n", module_->name().c_str()); + string output = + absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name()); for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = computation == module_->entry_computation() ? "entry " : ""; @@ -772,6 +756,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 62c7bb685dfea0fa91c06b9700dc9f54d70f429e..bcfcb388f95b0bedb35a8c399e804034816867b3 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -23,7 +23,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,10 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -249,7 +248,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleDomain(HloInstruction* domain) override; - Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; @@ -318,14 +316,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const PerInstruction* PerInst(const HloInstruction* inst) const { int id = inst->unique_id(); DCHECK_GE(id, 0); - DCHECK_LT(id, per_instruction_.size()); - return &per_instruction_[id]; + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + LOG(FATAL) << "Expected per-instruction information to already exist"; + } else { + return iter->second.get(); + } } PerInstruction* PerInst(const HloInstruction* inst) { int id = inst->unique_id(); DCHECK_GE(id, 0); - DCHECK_LT(id, per_instruction_.size()); - return &per_instruction_[id]; + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + return per_instruction_.emplace(id, absl::make_unique()) + .first->second.get(); + } else { + return iter->second.get(); + } } std::vector> GetAllUsesOfInstructionAtIndex( @@ -342,7 +349,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const std::unique_ptr logical_buffer_analysis_; // A map from instruction->unique_id() to - std::vector 
per_instruction_; + absl::flat_hash_map> per_instruction_; // A map from LogicalBuffer->id() to alias information about that logical // buffer diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 10d382e8abc92145c1804cbf18bbed714fa34571..d9ebebf74ed846aa05326a4df72019ef3e71ad88 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Checks that the given points-to set contains exactly (unordered) the given // LogicalBuffers. - void ExpectHasBuffers( - const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice buffers) { + void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set, + absl::Span buffers) { std::vector vec(buffers.begin(), buffers.end()); EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); } @@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // top-level buffers of the given instructions. void ExpectHasTopLevelBuffers( const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { PointsToSet::BufferList buffers; for (auto instruction : instructions) { buffers.push_back(GetBuffer(instruction, /*index=*/{})); @@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Overload which takes a set instead of a vector. void ExpectHasTopLevelBuffers( const PointsToSet::BufferSet& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { ExpectHasTopLevelBuffers( PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()), instructions); @@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // aliases which are exactly (unordered) the given instruction/index pairs. 
void ExpectHasBufferAliases( const HloInstruction* instruction, const ShapeIndex& index, - tensorflow::gtl::ArraySlice> - expected) { + absl::Span> expected) { const LogicalBuffer* buffer = points_to_analysis_->GetBufferDefinedAt(instruction, index) .ValueOrDie(); @@ -557,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + Literal elements[] = {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})}; + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -1012,6 +1010,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* 
operand_param = computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -1037,7 +1073,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); @@ -1066,8 +1103,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 8c91d6e69de637d58fa2ffc1a32ea65f09d3b6d8..e126a530234c1452bcf91f642f63d4c087935a56 100644 --- 
a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // A pass which simplifies patterns of Tuple and GetTupleElement instructions in // the module. -class TupleSimplifier : public HloPassInterface { +class TupleSimplifier : public HloModulePass { public: TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 39b693872da6bd985d95c2abc9519662c838a3f5..516754e2110ee50a597818c4a8bcfbfbb76c5cec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloVerifiedTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - 
Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { EXPECT_THAT(computation->root_instruction(), gte); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { EXPECT_THAT(computation->root_instruction(), element); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + Run(module, /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index 
4a530bb0b20582b303f4af969514748b46fd5064..cfb0c787d09557fd1aec3517eb9698cfec323369 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -40,7 +40,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values) { + absl::Span trailing_values) { CHECK(ShapeUtil::IsTuple(input_tuple->shape())); HloComputation* computation = input_tuple->parent(); diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h index e5ff9aaa8357fe8e4777d6dee37bbec72e144c06..bc5aac09f270c01515b1f3a704af6949f24cb218 100644 --- a/tensorflow/compiler/xla/service/tuple_util.h +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -38,7 +38,7 @@ class TupleUtil { // `input_tuple`. 
static HloInstruction* AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values); + absl::Span trailing_values); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 7e4ac92a7c5d1e75fbff586e6891cfbef86347c2..541b117e0299c94de330604ec5c16e20f07c425f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -183,8 +183,7 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); + StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); @@ -197,32 +196,27 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); // The initial value of the induction variable. 
- std::unique_ptr indvar_iter_val = - std::move(indvar_init_result).ValueOrDie(); + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); for (int64 trip_count = 0; trip_count != max_value_returned + 1; ++trip_count) { auto* while_cond = while_op->while_condition(); auto* while_cond_root = while_cond->root_instruction(); auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions( - while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + StatusOr result = evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{false}) { + if (result.ValueOrDie().data() == absl::Span{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. - StatusOr> indvar_next_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, - {{while_body_indvar, indvar_iter_val.get()}}); + StatusOr indvar_next_result = evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index aab11806621746141f4302f39a780fcdbab99fc1..067cfcc17d65860a249de4d9e31703df12091d3a 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -15,10 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 2dba7d7f7574742a301e3503e353bbe57d72a203..577bad6c7062d2ee40271e407e8eed7655fa13bf 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -50,7 +50,7 @@ namespace xla { // conditions as well. // // TODO(b/79121449): We should also sink broadcasts of constants. -class WhileLoopConstantSinking : public HloPassInterface { +class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index f4098f28b3d5cce3bb0bfc0a2ec5a05928366930..9795b2830b6d9add82b89ac76b5438ddc3d2bfe8 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -15,18 +15,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::InlinedVector; -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. All of its transitive operands are expected to be either @@ -34,8 +34,8 @@ using tensorflow::gtl::FlatSet; // function hoists the operands in `unhoisted_invariant_instructions` and moves // them into `hoisted_instructions`. static void CreateLoopInvariantCopy( - FlatMap* hoisted_instructions, - FlatSet* unhoisted_invariant_instructions, + flat_hash_map* hoisted_instructions, + flat_hash_set* unhoisted_invariant_instructions, HloInstruction* while_instr, HloInstruction* to_hoist) { HloComputation* parent_of_while = while_instr->parent(); HloComputation* while_body = while_instr->while_body(); @@ -110,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( case HloOpcode::kBitcast: case HloOpcode::kBroadcast: + case HloOpcode::kIota: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -146,13 +147,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // Maps instructions in the while body to instructions hoisted outside the // while that compute the same value. 
- FlatMap hoisted_instructions; + flat_hash_map hoisted_instructions; // Contains instructions that can be legally hoisted, but were deemed to be // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we // hoist an instruction in this set, we move it from // unhoisted_invariant_instructions to hoisted_instructions. - FlatSet unhoisted_invariant_instructions; + flat_hash_set unhoisted_invariant_instructions; // Invariant GTE's axiomatically satisfy the constraints for // unhoisted_invariant_instructions -- they can be legally hoisted, but there diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 2cdf20ce80362c0aeb9d8324573e7e9826cc018c..3031899f71e0fd77f20448d9d7489798af01615c 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that rewrites while loops to hoist loop invariant instructions in // the while body into the computation that contains the while instruction. -class WhileLoopInvariantCodeMotion : public HloPassInterface { +class WhileLoopInvariantCodeMotion : public HloModulePass { public: // If `hoist_constants` is true then constants are always hoisted out of while // loop bodies. 
Otherwise they are only hoisted out if they enable other diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index e14014b961d44cf723e1363e27c19c2e149c9057..32e69c335b713c438bd7fcb2053709b0624f58ed 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,10 +28,6 @@ namespace op = xla::testing::opcode_matchers; class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { public: - WhileLoopInvariantCodeMotionTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 6a7bfe3f129d97866ccc54897d584fab0f7c683e..630d71e5ca25e9d282ce6283284a32d6f725a193 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,12 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -114,7 +115,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } - tensorflow::gtl::FlatSet used_tuple_indices; + absl::flat_hash_set used_tuple_indices; for (HloComputation* comp : {while_body, while_cond}) { // The HLO verifier ensures that while_input's shape matches while_init's // shape, which we verified above is a tuple. @@ -181,7 +182,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.end()); std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); - tensorflow::gtl::FlatMap old_to_new_tuple_idx; + absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; old_to_new_tuple_idx[old_idx] = new_idx; @@ -252,7 +253,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. 
std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond)); + make_while_computation_replacements(while_cond), /*extras=*/{}); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -265,7 +266,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements)); + while_body->CloneWithReplacements(std::move(while_body_replacements), + /*extras=*/{}); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -404,7 +406,7 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // build a map from the tuple element index to the constant value. Limit this // to scalar constant values because propagating array constants can regress // performance by forcing us to copy constants. - tensorflow::gtl::FlatMap index_to_constant; + absl::flat_hash_map index_to_constant; for (int i = 0; i < root_operands.size(); i++) { HloInstruction* instr = root_operands[i]; if (instr->opcode() == HloOpcode::kGetTupleElement && diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 78024f14dc89ff40a11bbc3602072fda1fe6f312..0bc5a0107bbcfb3b29a01d593fb79b89a863e49b 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -30,7 +30,7 @@ namespace xla { // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. 
// -class WhileLoopSimplifier : public HloPassInterface { +class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} absl::string_view name() const override { return "simplify-while-loops"; } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index cfe4104f6d0afbb2a1c31aaf94ec53a0ba5e178e..1c892ba179ec67ccc9dbfe93d925551d6977ba15 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -28,11 +28,6 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { - public: - WhileLoopSimplifierTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false) {} - protected: // Makes an HloModule that contains a loop with `num_iters` iteration. void MakeModuleWithSimpleLoop(int num_iters); diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index e8f76ff745a7871cd75294ff63c336cf1ce36f19..f90ac91f9d07aded8cafccf82dae894c9a149bd1 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -94,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { /*static*/ StatusOr WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { CHECK(ShapeUtil::IsTuple(while_instr->shape())); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index e67636d80f4b682fe1335eae535fb86105ac082b..b1c4486887ae0ddbe2ba4e79f45a265689111017 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -55,7 +55,7 @@ 
class WhileUtil { // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); using LoopStateTy = std::vector; using LoopBodyGeneratorTy = std::function( diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index a7f0e207eb5a81b04bb28977d6f5e38864ad2d6a..87294120d51d244d9f2649cf95916f022bf829cb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -21,7 +21,7 @@ limitations under the License. // HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { -class ZeroSizedHloElimination : public HloPassInterface { +class ZeroSizedHloElimination : public HloModulePass { public: StatusOr Run(HloModule* module) override; absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index caad31d6ce7ce35fa362ec364b0d7f1d95973715..d44db89d571891ecef554cd45c050017833982bb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -25,8 +25,8 @@ namespace xla { Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(other_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(other_shape), + ShapeUtil::HumanString(shape())); } shape_ = other_shape; return Status::OK(); @@ -35,8 +35,8 @@ Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - 
ShapeUtil::HumanString(*to_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(*to_shape), + ShapeUtil::HumanString(shape())); } *to_shape = shape_; return Status::OK(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index c793a39c272154dfcc0d9c400d9642a567816dec..df610102b4c7fa08c0b7030124939009130f89f4 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -23,13 +23,13 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -224,14 +224,13 @@ class ShapeTree { // REQUIRES: index must exist in the ShapeTree. iterator find(ShapeIndexView index) { Node* element = Lookup(index); - return iterator(&nodes_, typename std::vector::iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.begin() + (element - &nodes_[0]); + return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); - return iterator(&nodes_, - typename std::vector::const_iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); + return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } // Returns the number of leaf nodes in the tree. 
@@ -262,6 +261,25 @@ class ShapeTree { template Status ForEachMutableElementWithStatus(const Fn& func); + // Maps each element to generate a new tree with the same shape. + template + ShapeTree Map(const std::function& func) { + ShapeTree result(shape_storage_); + ForEachElement([&](const ShapeIndex& index, const T& t) { + *result.mutable_element(index) = func(t); + }); + return result; + } + + template + ShapeTree Map(const std::function& func) { + ShapeTree result(shape_storage_); + ForEachMutableElement([&](const ShapeIndex& index, T* t) { + *result.mutable_element(index) = func(t); + }); + return result; + } + // Copy the subtree of values from 'other' rooted at ShapeIndex // 'source_base_index' into the subtree of value in this ShapeTree rooted at // 'target_base_index'. @@ -463,9 +481,6 @@ template ShapeTree::ShapeTree(Shape shape) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { - // The shape_ field is just used to hold the structure of the shape. - // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_storage_.get()); const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); @@ -502,9 +517,6 @@ template ShapeTree::ShapeTree(Shape shape, const T& init_value) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { - // The shape_ field is just used to hold the structure of the shape. - // It should not be relied upon to store layout information. 
- LayoutUtil::ClearLayout(shape_storage_.get()); const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 31ddd57eef5110141b04ff5c239007877220085b..7a34c0fb2641db3062337f9abf33b09a817f5bf5 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -95,11 +95,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, } if (ShapeUtil::IsTuple(lhs)) { - return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); + return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); + }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially // the same. @@ -111,13 +111,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, return false; } if (LayoutUtil::IsDenseArray(lhs)) { - if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { + if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } - if (!ContainersEqual(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { + if (!absl::c_equal(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions())) { VLOG(3) << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; return false; @@ -139,15 +139,15 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, // Constructs and returns the new shape with the given minor_to_major order in // its Layout. 
StatusOr MakeShapeWithLayoutInternal( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); } if (element_type == OPAQUE || element_type == TUPLE) { return InvalidArgument("Unsupported element type: %s", - PrimitiveType_Name(element_type).c_str()); + PrimitiveType_Name(element_type)); } Shape shape = ShapeUtil::MakeShape(element_type, dimensions); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); @@ -214,8 +214,8 @@ StatusOr MakeShapeWithLayoutInternal( return program_shape; } -/* static */ Shape ShapeUtil::MakeShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { +/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, + absl::Span dimensions) { CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); @@ -223,21 +223,21 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ Shape ShapeUtil::MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { + PrimitiveType element_type, absl::Span dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( - PrimitiveType element_type, 
tensorflow::gtl::ArraySlice dimensions, + PrimitiveType element_type, absl::Span dimensions, int64 max_sparse_elements) { CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); @@ -256,9 +256,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return MakeShapeWithDescendingLayout(shape.element_type(), dims); } -/* static */ void ShapeUtil::PopulateShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - Shape* shape) { +/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type, + absl::Span dimensions, + Shape* shape) { shape->Clear(); shape->set_element_type(element_type); for (int64 dimension : dimensions) { @@ -268,8 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } -/* static */ Shape ShapeUtil::MakeTupleShape( - tensorflow::gtl::ArraySlice shapes) { +/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); result.mutable_tuple_shapes()->Reserve(shapes.size()); @@ -423,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - CHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + if (shape.dimensions().size() == 1) { + return shape.dimensions()[0]; + } return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies()); @@ -442,12 +444,26 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return count; } +/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type) { + if (shape.element_type() == primitive_type) { + return true; + } + for (const Shape& element_shape : shape.tuple_shapes()) { + if 
(HasPrimitiveType(element_shape, primitive_type)) { + return true; + } + } + return false; +} + /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } -/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && Rank(shape) == 0; +/* static */ bool ShapeUtil::IsScalarWithElementType( + const Shape& shape, PrimitiveType element_type) { + return IsScalar(shape) && shape.element_type() == element_type; } namespace { @@ -491,8 +507,7 @@ StatusOr StringToPrimitiveType(const string& name) { }(); auto found = name_to_type->find(name); if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", - name.c_str()); + return InvalidArgument("Invalid element type string: \"%s\".", name); } return found->second; } @@ -564,8 +579,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (absl::ConsumePrefix(s, ")")) { break; } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", - string(*s).c_str()); + return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); @@ -583,7 +597,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { // we convert in to the RE2-consumable type and then consume the corresponding // amount from our string_view type. 
static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; + "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" + "?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, &dimensions_string, &format_string, &layout_string)) { @@ -593,8 +608,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { int64 element; if (!absl::SimpleAtoi(input, &element)) { return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - string(input).c_str(), string(*s).c_str()); + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, + *s); } return element; }; @@ -618,7 +633,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { StringToPrimitiveType(element_type_string)); if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string.c_str()); + element_type_string); } Shape result; @@ -648,16 +663,14 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return std::move(result); } - return InvalidArgument("Invalid shape string to parse: \"%s\"", - string(*s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); } } // namespace /* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", - string(s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", s); } return shape; } @@ -666,7 +679,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { const Shape& rhs) { CHECK(ShapeUtil::IsArray(lhs)); CHECK(ShapeUtil::IsArray(rhs)); - return ContainersEqual(lhs.dimensions(), rhs.dimensions()); + return absl::c_equal(lhs.dimensions(), rhs.dimensions()); } /* static */ bool 
ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { @@ -680,8 +693,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return IsArray(rhs) && SameDimensions(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringElementType); + absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringElementType); } else { // Opaque, token, etc types are vacuously compatible. return lhs.element_type() == rhs.element_type(); @@ -695,8 +708,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { CompatibleIgnoringElementType(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); + absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); } else { // Opaque, token, etc types are vacuously compatible. 
return lhs.element_type() == rhs.element_type(); @@ -795,7 +808,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - tensorflow::gtl::ArraySlice padded_dimensions = + absl::Span padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { CHECK_EQ(Rank(shape), padded_dimensions.size()); @@ -820,9 +833,10 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID || + !PrimitiveType_IsValid(shape.element_type())) { return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + shape.ShortDebugString()); } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { @@ -845,31 +859,27 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } return Status::OK(); } - if (Rank(shape) != shape.dimensions_size()) { - return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%lld " - "dimensions_size=%d", - Rank(shape), shape.dimensions_size()); + if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + return 
InvalidArgument("sparse arrays must have rank > 0"); } for (int64 i = 0; i < Rank(shape); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( - "shape's dimensions must not be < 0; dimension at index %lld was " - "%lld", - i, dimension); + "shape's dimensions must not be < 0; dimension at index %d was %d", i, + dimension); } } @@ -934,7 +944,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape_size < 0) { return InvalidArgument("Shape %s size may overflow int64.", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } VLOG(3) << "Shape size is valid: " << shape_size; @@ -994,7 +1004,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", - index.ToString().c_str(), shape.DebugString().c_str()); + index.ToString(), shape.DebugString()); } return_shape = &return_shape->tuple_shapes(i); } @@ -1040,7 +1050,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)); - return ArrayContains(AsInt64Slice(shape.dimensions()), 1); + return absl::c_linear_search(shape.dimensions(), 1); } namespace { @@ -1120,7 +1130,7 @@ Status ForEachMutableSubshapeHelper( } /* static */ Shape ShapeUtil::PermuteDimensions( - tensorflow::gtl::ArraySlice permutation, const Shape& shape) { + absl::Span permutation, const Shape& shape) { Shape new_shape = shape; new_shape.clear_dimensions(); for (auto dim : Permute(permutation, shape.dimensions())) { @@ -1264,7 +1274,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping) { + absl::Span dimension_mapping) { 
CHECK(LayoutUtil::HasLayout(input_shape) && LayoutUtil::HasLayout(output_shape)); @@ -1291,7 +1301,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // apply(input_dimensions, I) = // apply((dimension_mapping * output_dimensions), I) // input_dimensions = dimension_mapping * output_dimensions - return ContainersEqual( + return absl::c_equal( ComposePermutations(dimension_mapping, AsInt64Slice(output_shape.layout().minor_to_major())), input_shape.layout().minor_to_major()); @@ -1637,7 +1647,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } std::ostream& operator<<(std::ostream& out, const Shape& shape) { - out << ShapeUtil::HumanString(shape); + out << ShapeUtil::HumanStringWithLayout(shape); return out; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 84f36e48a0fb930958dfc13732bf15225eebb1ed..51cedce7f0e13e65dfd0e250689e0ecd30f971dc 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/container/inlined_vector.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -71,7 +72,7 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } - // push_front is O(n^2), but shapes don't usually have a ton of dimensions. 
+ // push_front is O(n), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } using container_type = absl::InlinedVector; @@ -131,12 +132,12 @@ class ShapeIndexView { } ShapeIndexView ConsumeFront() const { ShapeIndexView result = *this; - result.indices_.pop_front(); + result.indices_.remove_prefix(1); return result; } ShapeIndexView ConsumeBack() const { ShapeIndexView result = *this; - result.indices_.pop_back(); + result.indices_.remove_suffix(1); return result; } ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } @@ -147,7 +148,7 @@ class ShapeIndexView { string ToString() const; private: - tensorflow::gtl::ArraySlice indices_; + absl::Span indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -180,6 +181,10 @@ class ShapeUtil { // As ElementsIn(), but recurses through tuples. static int64 ElementsInRecursive(const Shape& shape); + // Returns true if shape has the primitive type, recurses through tuples. + static bool HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type); + // Returns true if 'shape' is an array with zero elements. static bool IsZeroElementArray(const Shape& shape); @@ -307,7 +312,10 @@ class ShapeUtil { static bool IsEffectiveScalar(const Shape& shape) { return IsArray(shape) && TrueRank(shape) == 0; } - static bool IsScalarF32(const Shape& shape); + + // Returns whether "shape" is a scalar (array) with the given element_type. + static bool IsScalarWithElementType(const Shape& shape, + PrimitiveType element_type); // Extracts the size of the shape's dimension at dimension number // GetDimensionNumber(dimension_number). @@ -328,7 +336,7 @@ class ShapeUtil { static Shape ChangeElementType(const Shape& original, PrimitiveType type); // Creates a tuple shape from a slice of element shapes within the tuple. 
- static Shape MakeTupleShape(tensorflow::gtl::ArraySlice shapes); + static Shape MakeTupleShape(absl::Span shapes); // Creates an opaque shape. These are generally used for threading a context // into a custom operation. @@ -355,31 +363,29 @@ class ShapeUtil { // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions template - static Shape MakeShapeWithType( - tensorflow::gtl::ArraySlice dimensions) { + static Shape MakeShapeWithType(absl::Span dimensions) { return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dimensions); } // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). - static Shape MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major); + static Shape MakeShapeWithLayout(PrimitiveType element_type, + absl::Span dimensions, + absl::Span minor_to_major); - static Shape MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - int64 max_sparse_elements); + static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + absl::Span dimensions, + int64 max_sparse_elements); // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( - PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + PrimitiveType element_type, absl::Span dimensions); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions @@ -391,8 +397,7 @@ class ShapeUtil { // As MakeShape, but the object to write to is passed in. 
static void PopulateShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - Shape* shape); + absl::Span dimensions, Shape* shape); // Validates that the provided shape satisfies invariants. static Status ValidateShape(const Shape& shape); @@ -478,8 +483,7 @@ class ShapeUtil { // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. - // - // DEPRECATED: Use Equal() instead. + ABSL_DEPRECATED("Use Equal() instead.") static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); @@ -539,7 +543,7 @@ class ShapeUtil { // !HasLayout(shape) || // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), // InversePermutation(permutation)). - static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, + static Shape PermuteDimensions(absl::Span permutation, const Shape& shape); // If we can go from `shape_pre` to `shape_post` by merely inserting or @@ -580,9 +584,9 @@ class ShapeUtil { // to its input and thus may be replaced with a bitcast. // // Precondition: Both input_shape and output_shape have explicit layouts. - static bool TransposeIsBitcast( - const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping); + static bool TransposeIsBitcast(const Shape& input_shape, + const Shape& output_shape, + absl::Span dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. @@ -621,12 +625,12 @@ class ShapeUtil { // continue, or false otherwise. // // visitor_function must be a callable of type - // StatusOr(ArraySlice) or compatible. + // StatusOr(Span) or compatible. 
template static Status ForEachIndexWithStatus(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { return ForEachIndexInternal(shape, base, count, incr, visitor_function); } @@ -648,13 +652,12 @@ class ShapeUtil { } template - static void ForEachIndex(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + static void ForEachIndex(const Shape& shape, absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { ForEachIndexWithStatus(shape, base, count, incr, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -676,7 +679,7 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { ForEachIndexWithStatus(shape, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -687,18 +690,18 @@ class ShapeUtil { // matter. // // visitor_function must be a callable of type - // void(ArraySlice) or compatible. + // void(Span) or compatible. template static void ForEachIndexParallel(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { // The parallel version of ForEachIndexInternal can never fail. 
CHECK(ForEachIndexInternal( shape, base, count, incr, - [&visitor_function](tensorflow::gtl::ArraySlice indexes) - -> StatusOr { + [&visitor_function]( + absl::Span indexes) -> StatusOr { visitor_function(indexes); return true; }, @@ -720,9 +723,9 @@ class ShapeUtil { template static Status ForEachIndexInternal(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function, bool parallel = false) { if (ShapeUtil::IsZeroElementArray(shape)) { diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 7549ba9c78025de06624f01d0e87956db27f4f9a..c622ecdca1fd66604d1a6ceaf705f2e70edaee55 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } +TEST(ShapeUtilTest, HasPrimitiveType) { + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}), + S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}), + S16)); +} + TEST(ShapeUtilTest, IsZeroElementArray) { EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); @@ -705,11 +721,10 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = 
ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = - [&invocations](tensorflow::gtl::ArraySlice indexes) { - invocations++; - return true; - }; + auto increment_func = [&invocations](absl::Span indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -726,8 +741,7 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) { // Increments at every invocation. int invocations = 0; auto increment_func = - [&invocations]( - tensorflow::gtl::ArraySlice indexes) -> StatusOr { + [&invocations](absl::Span indexes) -> StatusOr { if (++invocations == 5) { return Unimplemented("Cannot increment beyond 5."); } @@ -748,7 +762,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) { Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); int64 output[10][10]; int init = 5; - auto set_func = [&](tensorflow::gtl::ArraySlice indexes) { + auto set_func = [&](absl::Span indexes) { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; }; diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index 31844abd89a020c87c403353374a80fb639a3244..1c135dda864b3060b8bdc6369f18268d7c5c7f9e 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -33,7 +33,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, } SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices) + absl::Span indices) : SparseIndexArray(max_indices, rank, std::vector(indices.begin(), indices.end())) {} @@ -48,25 +48,24 @@ int64 SparseIndexArray::index_count() const { return indices_.size() / rank_; } -tensorflow::gtl::ArraySlice SparseIndexArray::At( +absl::Span SparseIndexArray::At( int64 sparse_element_number) const { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * 
sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::ArraySlice( + return absl::Span( indices_.data() + rank_ * sparse_element_number, rank_); } -tensorflow::gtl::MutableArraySlice SparseIndexArray::At( - int64 sparse_element_number) { +absl::Span SparseIndexArray::At(int64 sparse_element_number) { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::MutableArraySlice( - indices_.data() + rank_ * sparse_element_number, rank_); + return absl::Span(indices_.data() + rank_ * sparse_element_number, + rank_); } -void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { +void SparseIndexArray::Append(absl::Span index) { CHECK_GT(rank_, 0); CHECK_EQ(index.size(), rank_); indices_.insert(indices_.end(), index.begin(), index.end()); @@ -90,12 +89,12 @@ bool SparseIndexArray::Validate(const Shape& shape) const { if (num_indices < 2) { return true; } - tensorflow::gtl::ArraySlice last = At(0); + absl::Span last = At(0); if (!IndexUtil::IndexInBounds(shape, last)) { return false; } for (int64 n = 1; n < num_indices; ++n) { - tensorflow::gtl::ArraySlice next = At(n); + absl::Span next = At(n); if (!IndexUtil::IndexInBounds(shape, next)) { return false; } diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index 70fab3bea5d346f3f8f6a2e52267696934dc5990..a96d483462efd77ae4761541e8c79b2c84fa49f3 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -21,10 +21,10 @@ limitations under the License. 
#include #include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -65,7 +65,7 @@ class SparseIndexArray { SparseIndexArray(int64 max_indices, int64 rank, std::vector indices = {}); SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices); + absl::Span indices); // Returns the number of elements represented by the indices stored in the // array. @@ -73,12 +73,12 @@ class SparseIndexArray { // Returns a slice that refers to the given sparse index number. The argument // must be in the range [0, element_count()). - tensorflow::gtl::ArraySlice At(int64 sparse_element_number) const; - tensorflow::gtl::MutableArraySlice At(int64 sparse_element_number); + absl::Span At(int64 sparse_element_number) const; + absl::Span At(int64 sparse_element_number); // Adds the given index at the end of the array. The new size of the // SparseIndexArray must not exceed `max_indices`. - void Append(tensorflow::gtl::ArraySlice index); + void Append(absl::Span index); // Removes all indices from the array. void Clear(); @@ -96,8 +96,8 @@ class SparseIndexArray { int64 max_indices() const { return max_indices_; } // Returns a pointer to the int64 array that holds the sparse indices. - tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } - tensorflow::gtl::ArraySlice data() const { return indices_; } + absl::Span mutable_data() { return absl::MakeSpan(indices_); } + absl::Span data() const { return indices_; } // Sorts this sparse index array along with the set of corresponding values. 
// The indices and values are sorted in the lexicographic order of the @@ -115,7 +115,7 @@ class SparseIndexArray { // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; // template - void SortWithValues(tensorflow::gtl::MutableArraySlice values); + void SortWithValues(absl::Span values); private: std::vector indices_; @@ -124,8 +124,7 @@ class SparseIndexArray { }; template -void SparseIndexArray::SortWithValues( - tensorflow::gtl::MutableArraySlice values) { +void SparseIndexArray::SortWithValues(absl::Span values) { int64 num_elements = index_count(); CHECK_EQ(values.size(), num_elements); std::vector sort_order; diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc index 7377f88958dcb7daf3d3f4f0e07966fdc9294580..e54057c4007078c76b79fe44d5706665e266c083 100644 --- a/tensorflow/compiler/xla/sparse_index_array_test.cc +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -33,7 +33,7 @@ TEST(SparseIndexArrayTest, Sort) { std::vector values = { 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, }; - a.SortWithValues(&values); + a.SortWithValues(absl::MakeSpan(values)); ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8})); ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 6b29d833dac6565eac774957221c3cc8814d54ef..8a0ae330420531b833ed670118e6b6b1056bd358 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites" load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) # Generate test_suites for all 
backends, named "${backend}_tests". generate_backend_suites() @@ -69,14 +73,14 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -99,7 +103,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -131,6 +137,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -147,11 +154,31 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) +tf_cc_test( + name = "hlo_verified_test_base_test", + srcs = ["hlo_verified_test_base_test.cc"], + deps = [ + ":hlo_test_base", + ":hlo_verified_test_base", + ":test_macros_cpu", + ":test_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + 
tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -207,6 +234,7 @@ cc_library( "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -281,6 +309,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -393,6 +422,7 @@ xla_test( "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -561,6 +591,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -577,8 +608,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -601,8 +631,8 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -624,12 +654,11 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -643,6 +672,7 @@ xla_test( ], shard_count = 
48, tags = [ + "broken", "manual", "notap", ], @@ -1014,6 +1044,8 @@ xla_test( "//tensorflow/core:test", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1123,7 +1155,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1142,6 +1173,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1172,6 +1205,7 @@ xla_test_library( "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1428,7 +1462,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1441,6 +1474,7 @@ xla_test( "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1455,11 +1489,11 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1473,14 +1507,12 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", 
"//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -1490,7 +1522,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1511,6 +1543,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1645,8 +1678,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1659,13 +1692,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1789,7 +1822,7 @@ xla_test( tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", @@ -1825,6 +1858,7 @@ xla_test( 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1838,15 +1872,11 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1857,6 +1887,7 @@ xla_test( "//tensorflow/core:test", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -1864,10 +1895,8 @@ xla_test( name = "multioutput_fusion_test", srcs = ["multioutput_fusion_test.cc"], deps = [ - "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", @@ -1881,6 +1910,7 @@ xla_test( "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -2009,16 +2039,15 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - 
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -2092,7 +2121,7 @@ tf_cc_test( name = "sample_file_test", srcs = ["sample_file_test.cc"], data = ["isolated_convolution.hlo"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":hlo_test_base", "//tensorflow/compiler/xla:test", @@ -2117,30 +2146,44 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", ], ) xla_test( name = "iota_test", srcs = ["iota_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], + shard_count = 30, tags = [ "enable_for_xla_interpreter", + # Require optimized builds, iota_test_cpu is very slow in fastbuild. 
+ "optonly", ], deps = [ ":client_library_test_base", - ":literal_test_util", ":xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multiple_devices_on_host_test", + srcs = ["multiple_devices_on_host_test.cc"], + args = ["--xla_force_host_platform_device_count=4"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 577fd1ab3b9268a66ea3f0c7e62b7d2644136d6e..c257566fb218d4769aec0c793efb9256b023b7ea 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -35,14 +36,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -using tensorflow::gtl::ArraySlice; - class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 0.0001}; @@ -228,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{1, 0x7FFFFFFFFFFFFFFLL, @@ -242,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Add(lhs_param, rhs_param); @@ -268,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{-1, 0, @@ -281,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Sub(lhs_param, rhs_param); @@ -300,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { XlaBuilder b(TestName()); std::vector lhs{static_cast(0x8000000000000000ULL)}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); Lt(lhs_param, rhs_param); - ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); + ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)}); } TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { @@ -322,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + Literal a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); - auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); + auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); auto b_param = ConstantR1(&builder, b_values); auto sum1 = Add(a_constant, b_constant); @@ -433,8 +431,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { class IntegerDivideOpTest : public ArrayElementwiseOpTest { protected: template - void TestDivRem(ArraySlice dividends, ArraySlice divisors, - ArraySlice quotients, ArraySlice remainders) { + void TestDivRem(absl::Span dividends, absl::Span divisors, + absl::Span quotients, + absl::Span remainders) { { XlaBuilder builder(TestName()); XlaOp dividend; @@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + Literal param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto sum = ConstantR0(&b, 0.0f); - auto param = Parameter(&b, 0, param_literal->shape(), "param"); + auto param = Parameter(&b, 0, param_literal.shape(), "param"); for (float exponent : exponents) { sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } @@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { 
std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Pow(Exp(param0), param1); std::vector expected(values0.size()); @@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Log(Pow(param0, param1)); std::vector expected(values0.size()); @@ -1500,14 +1499,14 
@@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); @@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Div(param0, Exp(param1)); 
std::vector expected(values0.size()); @@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + client_->TransferToServer(literal2).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(Div(param0, param1), param2); std::vector expected(values0.size()); @@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = 
LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Div(param1, param2)); std::vector expected(values0.size()); @@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = 
Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); @@ -1650,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); + Literal literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = - client_->TransferToServer(*literal3).ConsumeValueOrDie(); + client_->TransferToServer(literal3).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); - auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); + auto param3 = Parameter(&b, 
3, literal3.shape(), "param2"); Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); @@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, @@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto 
p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); Array3D expected(0, 7, 0); @@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); - auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p = Parameter(&builder, 0, param0_literal.shape(), "param0"); Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, @@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Tanh(input); ComputeAndCompareR1( @@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. 
- std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Exp(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::exp(input_literal->Get({i}))); + expected_result.push_back(std::exp(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. 
XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Log(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::log(input_literal->Get({i}))); + expected_result.push_back(std::log(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{true, true}, {true, false}}), + LiteralUtil::CreateR2({{true, false}, {false, false}})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { @@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { 
std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = ConstantLiteral(&builder, *a_literal); + Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + auto a = ConstantLiteral(&builder, a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); auto y_literal = LiteralUtil::CreateR1({4, 5}); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); - auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); + auto y = Parameter(&builder, 1, y_literal.shape(), "y"); auto slice = Slice(x, {1}, {2}, {1}); Sub(slice, y); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index ac90a3adb6dbad30e3ef0b11438fb9a6fd6f8574..bc2ba151a38f1ab000b342dcd4bdd8f53d9ce9a9 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -63,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); + input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -242,14 
+242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { @@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { @@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), - LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)), + 
LiteralUtil::CreateR1(std::vector(260, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 0.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR3FromArray3D( - {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), - LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}), + LiteralUtil::CreateR1(std::vector(1, 15.0f)), + LiteralUtil::CreateR1(std::vector(1, 125.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - LiteralUtil::CreateR1({0, 0}).get(), - LiteralUtil::CreateR1({16, 20}).get()}); + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}), + LiteralUtil::CreateR1({0, 0}), + LiteralUtil::CreateR1({16, 20})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } struct BatchNormTestParam { @@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto 
scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); - auto expected = LiteralUtil::MakeTuple( - {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), - LiteralUtil::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_normalized, LiteralUtil::CreateR1(mean), + LiteralUtil::CreateR1(var)}); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); @@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { // testcase. 
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, *expected, + &builder, expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } @@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean"); auto variance_activations = - Parameter(&builder, 4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal.shape(), "variance"); Array4D expected = normalized; std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr variance_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); BatchNormInference(input_activations, scale_activations, 
offset_activations, mean_activations, variance_activations, epsilon, @@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = LiteralUtil::CreateR4FromArray4D(grad_output_array); - auto input_parameter = - Parameter(&builder, 0, input_literal->shape(), "input"); - auto scale_parameter = - Parameter(&builder, 1, scale_literal->shape(), "scale"); - auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); - auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); + auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input"); + auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance"); auto grad_output_parameter = - Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal.shape(), "grad_output"); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr var_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); std::unique_ptr grad_output_data = - client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); + client_->TransferToServer(grad_output_literal).ConsumeValueOrDie(); BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, grad_output_parameter, epsilon, feature_index); - auto expected = - 
LiteralUtil::MakeTuple({expected_grad_activation.get(), - LiteralUtil::CreateR1(grad_scale).get(), - LiteralUtil::CreateR1(grad_offset).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_grad_activation, LiteralUtil::CreateR1(grad_scale), + LiteralUtil::CreateR1(grad_offset)}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 6c20f654fe3df6a28e9633cd832c11b487894bad..e9728e636f0ee032416b2da17a3ea83c5bb18083 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -65,7 +65,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { Log(x); ComputeAndCompareR0(&builder, static_cast(1.387f), {}, - error_spec_); + ErrorSpec(0.01, 0.01)); } XLA_TEST_F(Bfloat16Test, NegateScalarF16) { @@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, - {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) - .get(), + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}), LiteralUtil::CreateR1( - {static_cast(4), static_cast(5)}) - .get(), + {static_cast(4), static_cast(5)}), LiteralUtil::CreateR1( - {static_cast(5), static_cast(5)}) - .get()}); + {static_cast(5), 
static_cast(5)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, - {{static_cast(3.f)}, {static_cast(3.f)}}}}) - .get(), + {{static_cast(3.f)}, {static_cast(3.f)}}}}), LiteralUtil::CreateR1( - {static_cast(0), static_cast(0)}) - .get(), + {static_cast(0), static_cast(0)}), LiteralUtil::CreateR1( - {static_cast(16), static_cast(20)}) - .get()}); + {static_cast(16), static_cast(20)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 1d28e85b16596b0ec2717138fb2081878203e8b2..dde19fb65d65064c9452a6ac49c70e20cf113336 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -53,29 +53,31 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { } } - std::unique_ptr MakeR3Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r3_shape, - Array3D* r3_array, float start, float end, int seed) { + std::unique_ptr MakeR3Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r3_shape, + Array3D* r3_array, float start, + float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = 
LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = - client_->TransferToServer(*r3_data).ConsumeValueOrDie(); + client_->TransferToServer(r3_data).ConsumeValueOrDie(); return r3_global_data; } - std::unique_ptr MakeR2Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r2_shape, - Array2D* r2_array, float start, float end, int seed) { + std::unique_ptr MakeR2Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r2_shape, + Array2D* r2_array, float start, + float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = - client_->TransferToServer(*r2_data).ConsumeValueOrDie(); + client_->TransferToServer(r2_data).ConsumeValueOrDie(); return r2_global_data; } @@ -291,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -299,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R3ImplicitBroadcastSpec { @@ -348,7 +350,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], 
spec.output_bounds[2]); - auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + auto Each = ([&](absl::Span indices, float* value) { float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], indices[1] % spec.input_bounds[1], indices[2] % spec.input_bounds[2]); @@ -368,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, - {r3_implicit_global_data.get(), r3_global_data.get()}, + &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()}, ErrorSpec(1e-7, 1e-7)); } @@ -393,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = 
LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + &b, LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + 
ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R2ImplicitBroadcastSpec { @@ -616,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, + &builder, expected, {r2_implicit_global_data1.get(), r2_global_data.get(), r2_implicit_global_data2.get()}, ErrorSpec(1e-6, 1e-6)); @@ -628,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, 
LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); auto expected = LiteralUtil::CreateR3( {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { @@ -695,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 
= ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -707,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { @@ -728,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { @@ -737,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 74d4d2eb10c32b270a83aa04dd2e6025d7a56c26..9966e4606ef7f104487182e0240e64e4c9e4d834 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), - 
*result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0(42.0), result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), result, error_spec_)); } @@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralSlice(*result, {0}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + LiteralSlice(result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralSlice(*result, {1}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + LiteralSlice(result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), result, error_spec_)); } @@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), result, error_spec_)); } @@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 
4.0}, {3.0, 4.0}}}), - *result, error_spec_)); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); 
expected.FillWithYX(to_broadcast); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 53f2c3bfbfce9585cb68f103a495ce2f1ad8432e..05d4d04034bf50c8bb840e59b28a590fce048c19 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -3,256 +3,266 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) all_backends = ["cpu", "gpu"] + plugins.keys() def filter_backends(backends): - """Removes "gpu" from a backend list if CUDA is not enabled. - - This allows us to simply hardcode lists including "gpu" here and in the - BUILD file, without causing failures when CUDA isn't enabled.' - - Args: - backends: A list of backends to filter. - - Returns: - The filtered list of backends. 
- """ - if cuda_is_configured(): - return backends - else: - return [backend for backend in backends if backend != "gpu"] - - -def xla_test(name, - srcs, - deps, - xla_test_library_deps=[], - backends=[], - blacklisted_backends=[], - args=[], - tags=[], - copts=[], - data=[], - backend_tags={}, - backend_args={}, - **kwargs): - """Generates cc_test targets for the given XLA backends. - - This rule generates a cc_test target for one or more XLA backends and also a - platform-agnostic cc_library rule. The arguments are identical to cc_test with - two additions: 'backends' and 'backend_args'. 'backends' specifies the - backends to generate tests for ("cpu", "gpu"), and - 'backend_args'/'backend_tags' specifies backend-specific args parameters to - use when generating the cc_test. - - The name of the cc_tests are the provided name argument with the backend name - appended, and the cc_library target name is the provided name argument with - "_lib" appended. For example, if name parameter is "foo_test", then the cpu - test target will be "foo_test_cpu" and the cc_library target is "foo_lib". - - The cc_library target can be used to link with other plugins outside of - xla_test. - - The build rule also defines a test suite ${name} which includes the tests for - each of the supported backends. - - Each generated cc_test target has a tag indicating which backend the test is - for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These - tags can be used to gather tests for a particular backend into a test_suite. - - Examples: - - # Generates the targets: foo_test_cpu and foo_test_gpu. - xla_test( - name = "foo_test", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) + """Removes "gpu" from a backend list if CUDA is not enabled. - # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu - # includes the additional arg "--special_cpu_flag". 
- xla_test( - name = "bar_test", - srcs = ["bar_test.cc"], - backends = ["cpu", "gpu"], - backend_args = {"cpu": ["--special_cpu_flag"]} - deps = [...], - ) + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' - The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} - to the value 1 where ${BACKEND} is the uppercase name of the backend. - - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - xla_test_library_deps: If set, the generated test targets will depend on the - respective cc_libraries generated by the xla_test_library rule. - backends: A list of backends to generate tests for. Supported values: "cpu", - "gpu". If this list is empty, the test will be generated for all supported - backends. - blacklisted_backends: A list of backends to NOT generate tests for. - args: Test arguments for the target. - tags: Tags for the target. - copts: Additional copts to pass to the build. - data: Additional data to pass to the build. - backend_tags: A dict mapping backend name to list of additional tags to - use for that target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. - **kwargs: Additional keyword arguments to pass to native.cc_test. 
- """ - test_names = [] - if not backends: - backends = all_backends - - backends = [backend for backend in backends - if backend not in blacklisted_backends] - - native.cc_library( - name="%s_lib" % name, - srcs=srcs, - copts=copts, - testonly=True, - deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], - ) - - for backend in filter_backends(backends): - test_name = "%s_%s" % (name, backend) - this_backend_tags = ["xla_%s" % backend] - this_backend_copts = [] - this_backend_args = backend_args.get(backend, []) - this_backend_data = [] - if backend == "cpu": - backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - elif backend == "gpu": - backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] - this_backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_deps = [] - backend_deps += plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] - this_backend_tags += plugins[backend]["tags"] - this_backend_args += plugins[backend]["args"] - this_backend_data += plugins[backend]["data"] - else: - fail("Unknown backend %s" % backend) - - if xla_test_library_deps: - for lib_dep in xla_test_library_deps: - backend_deps += ["%s_%s" % (lib_dep, backend)] - - tf_cc_test( - name=test_name, - srcs=srcs, - tags=tags + backend_tags.get(backend, []) + this_backend_tags, - extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + - this_backend_copts, - args=args + this_backend_args, - deps=deps + backend_deps, - data=data + this_backend_data, - **kwargs) - - test_names.append(test_name) - - native.test_suite(name=name, tests=test_names) - -def xla_test_library(name, - srcs, - hdrs=[], - deps=[], - backends=[]): - """Generates cc_library targets for the given XLA backends. 
- - This rule forces the sources to be compiled for each backend so that the - backend specific macros could expand correctly. It's useful when test targets - in different directories referring to the same sources but test with different - arguments. - - Examples: - - # Generates the targets: foo_test_library_cpu and foo_test_gpu. - xla_test_library( - name = "foo_test_library", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) - # Then use the xla_test rule to generate test targets: - xla_test( - name = "foo_test", - srcs = [], - backends = ["cpu", "gpu"], - deps = [...], - xla_test_library_deps = [":foo_test_library"], - ) + Args: + backends: A list of backends to filter. - Args: - name: Name of the target. - srcs: Sources for the target. - hdrs: Headers for the target. - deps: Dependencies of the target. - backends: A list of backends to generate libraries for. - Supported values: "cpu", "gpu". If this list is empty, the - library will be generated for all supported backends. - """ - - if not backends: - backends = all_backends - - for backend in filter_backends(backends): - this_backend_copts = [] - if backend in ["cpu", "gpu"]: - backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] - elif backend in plugins: - backend_deps = plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] + Returns: + The filtered list of backends. + """ + if cuda_is_configured(): + return backends else: - fail("Unknown backend %s" % backend) + return [backend for backend in backends if backend != "gpu"] + +def xla_test( + name, + srcs, + deps, + xla_test_library_deps = [], + backends = [], + blacklisted_backends = [], + args = [], + tags = [], + copts = [], + data = [], + backend_tags = {}, + backend_args = {}, + **kwargs): + """Generates cc_test targets for the given XLA backends. + + This rule generates a cc_test target for one or more XLA backends and also a + platform-agnostic cc_library rule. 
The arguments are identical to cc_test with + two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "gpu"), and + 'backend_args'/'backend_tags' specifies backend-specific args parameters to + use when generating the cc_test. + + The name of the cc_tests are the provided name argument with the backend name + appended, and the cc_library target name is the provided name argument with + "_lib" appended. For example, if name parameter is "foo_test", then the cpu + test target will be "foo_test_cpu" and the cc_library target is "foo_lib". + + The cc_library target can be used to link with other plugins outside of + xla_test. + + The build rule also defines a test suite ${name} which includes the tests for + each of the supported backends. + + Each generated cc_test target has a tag indicating which backend the test is + for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These + tags can be used to gather tests for a particular backend into a test_suite. + + Examples: + + # Generates the targets: foo_test_cpu and foo_test_gpu. + xla_test( + name = "foo_test", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + + # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu + # includes the additional arg "--special_cpu_flag". + xla_test( + name = "bar_test", + srcs = ["bar_test.cc"], + backends = ["cpu", "gpu"], + backend_args = {"cpu": ["--special_cpu_flag"]} + deps = [...], + ) + + The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} + to the value 1 where ${BACKEND} is the uppercase name of the backend. + + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + xla_test_library_deps: If set, the generated test targets will depend on the + respective cc_libraries generated by the xla_test_library rule. + backends: A list of backends to generate tests for. Supported values: "cpu", + "gpu". 
If this list is empty, the test will be generated for all supported + backends. + blacklisted_backends: A list of backends to NOT generate tests for. + args: Test arguments for the target. + tags: Tags for the target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. + backend_tags: A dict mapping backend name to list of additional tags to + use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. + """ + test_names = [] + if not backends: + backends = all_backends + + backends = [ + backend + for backend in backends + if backend not in blacklisted_backends + ] native.cc_library( - name = "%s_%s" % (name, backend), + name = "%s_lib" % name, srcs = srcs, + copts = copts, testonly = True, - hdrs = hdrs, - copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] - + this_backend_copts, - deps = deps + backend_deps, + deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - -def generate_backend_suites(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - native.test_suite(name="%s_tests" % backend, - tags = ["xla_%s" % backend]) - - -def generate_backend_test_macros(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - manifest = "" - if backend in plugins: - manifest = plugins[backend]["disabled_manifest"] - - native.cc_library( - name="test_macros_%s" % backend, - testonly = True, - srcs = ["test_macros.cc"], - hdrs = ["test_macros.h"], - copts = [ - "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), - "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, - ], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:test", - ]) + for backend in filter_backends(backends): + test_name = "%s_%s" % (name, backend) + 
this_backend_tags = ["xla_%s" % backend] + this_backend_copts = [] + this_backend_args = backend_args.get(backend, []) + this_backend_data = [] + if backend == "cpu": + backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] + elif backend == "gpu": + backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] + this_backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_deps = [] + backend_deps += plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + this_backend_tags += plugins[backend]["tags"] + this_backend_args += plugins[backend]["args"] + this_backend_data += plugins[backend]["data"] + else: + fail("Unknown backend %s" % backend) + + if xla_test_library_deps: + for lib_dep in xla_test_library_deps: + backend_deps += ["%s_%s" % (lib_dep, backend)] + + tf_cc_test( + name = test_name, + srcs = srcs, + tags = tags + backend_tags.get(backend, []) + this_backend_tags, + extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + args = args + this_backend_args, + deps = deps + backend_deps, + data = data + this_backend_data, + **kwargs + ) + + test_names.append(test_name) + + native.test_suite(name = name, tests = test_names) + +def xla_test_library( + name, + srcs, + hdrs = [], + deps = [], + backends = []): + """Generates cc_library targets for the given XLA backends. + + This rule forces the sources to be compiled for each backend so that the + backend specific macros could expand correctly. It's useful when test targets + in different directories referring to the same sources but test with different + arguments. + + Examples: + + # Generates the targets: foo_test_library_cpu and foo_test_gpu. 
+ xla_test_library( + name = "foo_test_library", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + # Then use the xla_test rule to generate test targets: + xla_test( + name = "foo_test", + srcs = [], + backends = ["cpu", "gpu"], + deps = [...], + xla_test_library_deps = [":foo_test_library"], + ) + + Args: + name: Name of the target. + srcs: Sources for the target. + hdrs: Headers for the target. + deps: Dependencies of the target. + backends: A list of backends to generate libraries for. + Supported values: "cpu", "gpu". If this list is empty, the + library will be generated for all supported backends. + """ + + if not backends: + backends = all_backends + + for backend in filter_backends(backends): + this_backend_copts = [] + if backend in ["cpu", "gpu"]: + backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + else: + fail("Unknown backend %s" % backend) + + native.cc_library( + name = "%s_%s" % (name, backend), + srcs = srcs, + testonly = True, + hdrs = hdrs, + copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + deps = deps + backend_deps, + ) + +def generate_backend_suites(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + native.test_suite( + name = "%s_tests" % backend, + tags = ["xla_%s" % backend, "-broken", "manual"], + ) + +def generate_backend_test_macros(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + manifest = "" + if backend in plugins: + manifest = plugins[backend]["disabled_manifest"] + + native.cc_library( + name = "test_macros_%s" % backend, + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + copts = [ + "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), + "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, + ], + deps = 
[ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], + ) diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index b1d18210eaafdfec0920c0cccaa0dfdbd6de5609..8b31e53707eee456e09adfe9fb76f03a8855056d 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = - ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); + auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); - auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({1.0f, 2.0f})); auto y = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -133,7 
+132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); + client_->TransferToServer(LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } @@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); - Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); + auto tuple = LiteralUtil::MakeTuple({&elem}); + Call(&builder, callee, {ConstantLiteral(&builder, elem)}); - ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); + ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index a4eb57fc7b9abd460a7d158d0dc629eba88018cd..2f1510ff6969757f8091e9c043b61cb2a467ccd5 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); - auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1"); Add(p0, p1); auto param0_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto param1_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); @@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { auto computation = computation_status.ConsumeValueOrDie(); auto f32_literal = LiteralUtil::CreateR0(1.1f); - auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); + auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie(); auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = - client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); + client_->TransferToServer(f32_4_literal).ConsumeValueOrDie(); auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); - auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); + auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie(); // Match auto status = client_->Execute( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 9cd974fd9bbb9f0f9bf316feb1c735106ed2bf07..fbdf0fcb6543f09dedefef55cfe0f8a5d9067d5a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -95,15 +95,14 @@ string ClientLibraryTestBase::TestName() const { } StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Build the computation, as a convenience. 
TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -114,18 +113,16 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr> -ClientLibraryTestBase::ExecuteAndTransferReference( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, +StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -138,7 +135,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference( } string ClientLibraryTestBase::ExecuteToString( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); @@ -150,29 +147,28 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return result.ValueOrDie()->ToString(); + return 
result.ValueOrDie().ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout) { + absl::Span arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); @@ -180,12 +176,12 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output) { // Try with no layout requirement. TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); - verify_output(*actual, ""); + verify_output(actual, ""); // Try with all output layouts. 
std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); @@ -196,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, + verify_output(actual, absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); @@ -205,7 +201,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout) { @@ -221,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. 
- if (ShapeUtil::IsTuple(literal->shape())) { + if (ShapeUtil::IsTuple(literal.shape())) { layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal->shape())); + ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -231,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); + std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = - literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); TF_ASSIGN_OR_RETURN(auto data, - client_->TransferToServer(*literal_relayout)); + client_->TransferToServer(literal_relayout)); arguments_with_layout.push_back(data.get()); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -252,15 +248,14 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // Every argument has an assigned layout. 
TF_ASSIGN_OR_RETURN( auto actual, - ExecuteAndTransfer( - computation, - tensorflow::gtl::ArraySlice(arguments_with_layout), - output_with_layout)); + ExecuteAndTransfer(computation, + absl::Span(arguments_with_layout), + output_with_layout)); string error_message = "Test with input layouts: "; for (const auto& str : layout_strings) { absl::StrAppend(&error_message, str, " "); } - verify_output(*actual, error_message); + verify_output(actual, error_message); return Status::OK(); }; @@ -269,7 +264,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, + absl::Span arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -290,19 +285,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; - } else { - TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || - expected.shape().element_type() == PRED) - << ShapeUtil::HumanString(expected.shape()); } // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. 
const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -327,14 +318,14 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); return Status::OK(); } Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, - ErrorSpec error, const Shape* shape_with_layout) { + absl::Span arguments_passed_in, ErrorSpec error, + const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -350,17 +341,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } - TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. 
const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -386,13 +375,13 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( XlaBuilder* builder, absl::string_view expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -401,66 +390,65 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. 
- std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + Literal expected_literal = LiteralUtil::CreateR1U8(expected); - VLOG(1) << "expected: " << expected_literal->ToString(); - VLOG(1) << "actual: " << actual->ToString(); + VLOG(1) << "expected: " << expected_literal.ToString(); + VLOG(1) << "actual: " << actual.ToString(); - EXPECT_EQ(expected, actual->GetR1U8AsString()); + EXPECT_EQ(expected, actual.GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); } void 
ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, - ErrorSpec error) { + XlaBuilder* builder, absl::Span arguments, ErrorSpec error) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr, std::unique_ptr>> +StatusOr> ClientLibraryTestBase::ComputeValueAndReference( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's // into a vector to keep the data alive on the service until the end of this // function. @@ -580,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return ConstantLiteral(builder, use_bfloat16_ - ? *LiteralUtil::ConvertF32ToBF16(literal) - : literal); + ? 
LiteralUtil::ConvertF32ToBF16(literal) + : LiteralSlice(literal)); } std::unique_ptr @@ -611,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { if (use_bfloat16_) { - return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ac96d3e325b84a51201158906fe9342df736aec0..9d32f4f5174a57a53a9d3e6477b46fa4de852f7f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -49,8 +49,8 @@ namespace xla { // use_bfloat16_params with that value. Returns the result. template std::vector ExpandUseBfloat16( - tensorflow::gtl::ArraySlice use_bfloat16_params, - tensorflow::gtl::ArraySlice specs) { + absl::Span use_bfloat16_params, + absl::Span specs) { std::vector expanded; for (bool use_bfloat16 : use_bfloat16_params) { for (const auto& spec : specs) { @@ -93,29 +93,29 @@ class ClientLibraryTestBase : public ::testing::Test { // execution options. 
Modify execution_options_ in your test if you want to // customize the options. StatusOr> Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, absl::Span arguments); - StatusOr> ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + StatusOr ExecuteAndTransfer( + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // This executes the computation via the reference client (which connects a // interpreter backend). The result is used as the expected values of the // computation. - StatusOr> ExecuteAndTransferReference( + StatusOr ExecuteAndTransferReference( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are @@ -125,102 +125,98 @@ class ClientLibraryTestBase : public ::testing::Test { // for integral types without the ErrorSpec parameter. 
template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span expected, + absl::Span arguments); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span expected, + absl::Span arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. 
shape_with_layout indicates the result layout to request when // calling Execute. - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout = nullptr); - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error, + const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_layout = nullptr); Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. - void ComputeAndCompareR1U8( - XlaBuilder* builder, absl::string_view expected, - tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected, + absl::Span arguments); // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. 
- void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the reference result. void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); + absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. XlaComputation CreateScalarRelu(); @@ -286,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*LiteralUtil::CreateFromArray(argument), builder); + return AddParam(LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -301,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array), builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), + return CreateConstantFromLiteral(LiteralUtil::CreateR0(value), builder); } @@ -337,7 +333,7 @@ class ClientLibraryTestBase : public ::testing::Test { // converted to bfloat16. 
template std::unique_ptr CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array @@ -379,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + StatusOr> ComputeValueAndReference( + XlaBuilder* builder, absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. @@ -390,12 +385,12 @@ class ClientLibraryTestBase : public ::testing::Test { private: Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output); Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); @@ -415,130 +410,126 @@ class ClientLibraryTestBase : public ::testing::Test { template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT 
expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + XlaBuilder* builder, absl::Span expected, + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + XlaBuilder* builder, absl::Span expected, + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, 
const Array2D& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = 
LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = + absl::Span arguments) { + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -546,27 +537,27 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR0(value); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = 
Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR1(values); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -574,13 +565,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ 
-588,13 +579,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index c898dacf489db97223e2918414daf5de88bece64..6f2ca84bb646e88af221ab80b727911ff7d990eb 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( - auto computed, client_->Transfer(*data, &expected_literal->shape())); + auto computed, client_->Transfer(*data, &expected_literal.shape())); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + 
EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } @@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralSlice(*result, {0})); + LiteralSlice(result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralSlice(*result, {1})); + LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::GetTupleElementShape(result.shape(), 0), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}))); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::GetTupleElementShape(result.shape(), 1), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0}))); } @@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); + LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN( auto result_literal, - client_->Transfer(*results[0], &expected_result->shape())); + client_->Transfer(*results[0], &expected_result.shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 
022641394f113ef28e7c53058385d77572822213..fbebe0408730f2fb37aa57a0f19291bbaa3826f9 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -32,11 +32,10 @@ StatusOr> CodegenTestBase::CompileToAotCompilationResult( std::unique_ptr hlo_module, const AotCompilationOptions& options) { - std::vector> hlo_modules; - hlo_modules.push_back(std::move(hlo_module)); + auto module_group = absl::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( std::vector> results, - backend().compiler()->CompileAheadOfTime(std::move(hlo_modules), + backend().compiler()->CompileAheadOfTime(std::move(module_group), options)); return std::move(results.front()); } diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 7c52c9fbbb57f9291ea9f0966e2efa715819fb67..6ef7ca035f75966bef12c7abcb55cb59e9b73655 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -38,25 +38,24 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: - void ExecuteComputationR0F32( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, float expected_result, - bool expect_cache_hit) { + void ExecuteComputationR0F32(const XlaComputation& computation, + absl::Span arguments, + float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - std::unique_ptr result = + Literal result = client_ ->ExecuteAndTransfer(computation, arguments, /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); + LiteralUtil::CreateR0(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } void ExecuteComputationR2F32( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; @@ -64,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase { ->Execute(computation, arguments, &execution_options_, &execution_profile) .ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data_handle).ConsumeValueOrDie(); + Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); + LiteralUtil::CreateR2(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, 
DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -146,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = - client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(rowmaj_array).ConsumeValueOrDie(); auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = - client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 8226b6de3f780197bc0f1145b617dba99803927f..3b0414a6045a7c5f4f75948d8ccf2775c575626e 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test { LOG(FATAL) << "invalid client_type value"; } - StatusOr> ComputeConstantLiteral( - Client* client, const XlaOp& operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + StatusOr 
ComputeConstantLiteral(Client* client, const XlaOp& operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); TF_ASSIGN_OR_RETURN(auto computed, client->ComputeConstant(subgraph, output_layout)); @@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test { XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder, nullptr)); - return literal->Get({}); + return literal.Get({}); } bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { @@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR1({4, 6}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR0(5); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), 
computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e..9811a015e91d866d6f4de6ebb6dac536ed6c7e06 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); auto x_literal = LiteralUtil::CreateR0(2.f); auto y_literal = LiteralUtil::CreateR0(3.f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto x = Parameter(&builder, 0, f32_scalar, "x"); @@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "z"); auto bcast = Broadcast(y, {5}); @@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, 
ConcatBroadcastArgumentR3) { auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "y"); auto y_bcast = Broadcast(y, {1, 5, 7}); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index b27c1044baf2c0002f166c53a81e4361c60d012a..32cac499c7439af80bafb88ac61b0b078f589599 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), - LiteralUtil::CreateR0(25.0f).get()}), + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(12.0f), + LiteralUtil::CreateR0(25.0f)}), {pred_arg.get()}, error_spec_); } @@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, CreateR1TupleFloorComputation()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), - LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + 
LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({13.0f, 16.0f}), + LiteralUtil::CreateR1({26.0f, 30.0f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(true).get(), - LiteralUtil::CreateR0(12.2f).get(), - LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(true), + LiteralUtil::CreateR0(12.2f), + LiteralUtil::CreateR1({12.8f, 14.6f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. @@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(46.6f).get(), - LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) - .get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), - LiteralUtil::CreateR0(9.3f).get()}) - .get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(46.6f), + LiteralUtil::CreateR1({54.4f, 58.4f})}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({62.1f, 67.4f}), + LiteralUtil::CreateR0(9.3f)})}), {pred_arg.get()}, error_spec_); } @@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), - LiteralUtil::CreateR0(b).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), {x_arg.get(), y_arg.get()}, error_spec_); }; @@ -642,5 +638,57 @@ 
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { test_swap(11.24f, 5.55f); } +// Test conditional that duplicates tuple elements in the then and else +// computations. This is a regression test for b/112550242. +XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { + const Shape scalar = ShapeUtil::MakeShape(S32, {}); + const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar}); + XlaComputation then_comp; + { + XlaBuilder builder(TestName() + ".then"); + auto p = Parameter(&builder, 0, tuple2, "then.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e0}); + then_comp = builder.Build().ConsumeValueOrDie(); + } + XlaComputation else_comp; + { + XlaBuilder builder(TestName() + ".else"); + auto p = Parameter(&builder, 0, tuple2, "else.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e1}); + else_comp = builder.Build().ConsumeValueOrDie(); + } + + { + // Pred is true case. + std::vector args; + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(true)); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } + { + // Pred is false case. 
+ std::vector args; + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(false)); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 49375748319ad5fe40db507a034ec4b07adb7e84..72ff1e74a47c8584cb5336c86a1c978c4637a902 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D( Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *input_literal); + ConstantLiteral(&builder, input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. 
TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})})); - std::unique_ptr result = - ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); + Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + LiteralSlice(result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(result, {1}), error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateToken()); + ConstantLiteral(&builder, LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. 
Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 7a203d6873dbb5b69f96c50048c2c5ff3150c544..5f063e67847487f1d18bf4ee80b1634ebdf4183a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = 
LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, U32); @@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, 
arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 38b6da4fa96b0f6b7ed2d56852eb3ab2872f3520..fd98bf29b8a06d7476d51174b61c6268750db2ec 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index d2c6478b02423c93860244bc5eb91e652a3eac2e..3aebf784664dac14ba2ea45c5a229b7b2e4fc39d 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public 
ConvolutionTest { XlaBuilder builder(TestName()); auto lhs = ConstantR4FromArray4D(&builder, *alhs); auto rhs = ConstantR4FromArray4D(&builder, *arhs); - Conv(lhs, rhs, {1, 1}, Padding::kValid); + PrecisionConfig precision; + // The left hand side of the convolution is numbers between 0 and 2304 which + // requires at least 11 mantissa bits and the DEFAULT precision config is + // allowed to round to bfloat16 which only has 7 mantissa bits. + precision.add_operand_precision(PrecisionConfig::HIGHEST); + precision.add_operand_precision(PrecisionConfig::DEFAULT); + Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, + &precision); ComputeAndCompare(&builder, {}, error_spec_); } @@ -123,8 +130,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -157,8 +164,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -192,8 +199,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -224,8 +231,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - 
std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -249,10 +256,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +291,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +326,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +357,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - 
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +393,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -435,23 +442,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); - auto 
input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r5).ConsumeValueOrDie(); + client_->TransferToServer(filter_r5).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r5, + ComputeAndCompareLiteral(&builder, expected_r5, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -498,23 +505,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -558,12 +565,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); 
iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(16029), static_cast(16218), static_cast(16407), @@ -571,14 +578,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { static_cast(18369), static_cast(18576), static_cast(18783), static_cast(19620), static_cast(19836), static_cast(20052), static_cast(20925), static_cast(21150), static_cast(21375)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -624,26 +631,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto 
filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(5076), static_cast(5160), static_cast(5244), static_cast(5328), static_cast(6164), static_cast(6264), static_cast(6364), static_cast(6464), static_cast(7380), static_cast(7496), static_cast(7612), static_cast(7728)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -692,8 +699,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(param0)), - std::move(*LiteralUtil::CreateFromArray(param1))}, + {LiteralUtil::CreateFromArray(param0), + LiteralUtil::CreateFromArray(param1)}, error_spec_); } @@ -749,26 +756,25 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); 
std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); auto expected_r1 = LiteralUtil::CreateR1(expect_elems); - auto expected_r3 = - expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); + auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + client_->TransferToServer(input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, + client_->TransferToServer(filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -868,8 +874,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } @@ -877,7 +883,7 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { // (We run this test on all platforms, because, what the heck.) 
XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( - "cudnn-convolution-algorithm-picker"); + "cudnn-conv-algorithm-picker"); XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); @@ -891,9 +897,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { Array4D filter_data(1, 1, 1, 2); filter_data.FillIota(10); - ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}); + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}); +} + +XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100}); + Array4D input_data(1, 64, 100, 100); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321); + Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64}); + Array4D filter_data(7, 7, 1, 64); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = ConstantR4FromArray4D(&builder, filter_data); + + // Specify bf01_01io->bf01 as dimension numbers. 
+ ConvolutionDimensionNumbers dnums; + // Input + dnums.set_input_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + // Kernel + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + // Output + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + dnums.add_output_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(3); + ConvGeneral(input, filter, /*window_strides=*/{1, 1}, + /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums, + /*feature_group_count=*/64); + + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)}, + error_spec_); } class ConvolutionHloTest : public HloTestBase {}; diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 6784c16715da72d337edf70fa51db42c59404136..ba3e9c436e3cfa574a07e881a187ff4c7d6243a1 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = - weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = ConstantLiteral(&builder, *weights_literal); + weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto weights = ConstantLiteral(&builder, weights_literal); auto expected_flat = 
LiteralUtil::CreateR1({10}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = Rev(weights, {2, 3, 4}); ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { @@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = - activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = ConstantLiteral(&builder, *activations_literal); + activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); + auto activations = ConstantLiteral(&builder, activations_literal); auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = ConvGeneralDilated(activations, gradients, @@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); Transpose(forward_conv, {0, 1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } } // namespace diff 
--git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 50a9ebc1e9915d5e8ad8d02276987784fe30b8fc..1407e68d9a336b6bb1c960711015430f872aa912 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -40,49 +40,48 @@ class CopyOpTest : public HloTestBase { protected: void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + auto constant = + builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); + Literal result = ExecuteAndTransfer(std::move(module), {}); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); + TestCopyOp(LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); + TestCopyOp(LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 
3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { @@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = LiteralUtil::CreateR0(42.0); - Shape shape = literal->shape(); + Shape shape = literal.shape(); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {literal.get()}); - LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {&literal}); + LiteralTestUtil::ExpectR0Near(42.0f, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { @@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major 
order of the literal. - Layout* literal_layout = - literal->mutable_shape_do_not_use()->mutable_layout(); + Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); @@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. - LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, *result, + LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, result, error_spec_); } @@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + Literal literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -182,14 +178,14 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR3EqualArray3D(a, *result); + LiteralTestUtil::ExpectR3EqualArray3D(a, result); } -void CopyOpTest::TestCopyConstantLayoutR4( - size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation) { +void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, + size_t n4, + absl::Span permutation) { Array4D a(n1, n2, 
n3, n4); for (size_t i = 0; i < n1; ++i) { for (size_t j = 0; j < n2; ++j) { @@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + Literal literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4( auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR4EqualArray4D(a, *result); + LiteralTestUtil::ExpectR4EqualArray4D(a, result); } XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) { @@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { XlaBuilder builder(TestName()); Parameter(&builder, 0, in_shape, "input"); - auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index d12a4e7fcd7813775a81677bcaa07af60ff9b477..410732c07b7b6d3ece33ab11f4778241dc53ca50 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = LiteralUtil::CreateR1({1, 
2, 3}); - EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); + EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); } XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { @@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ( - *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0, &literal1})); } // On the GPU backend, constants get special handling. Someone might pass a @@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 6f7fc0e6e52a69387a4c491871b6fcd97ac638b6..001490c6a8c568656437465054ee4db40d0d8dee 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + 
LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { @@ -101,12 +101,11 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, - DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { auto module = CreateNewModule(); auto b = HloComputation::Builder(TestName()); @@ -125,9 +124,56 @@ XLA_TEST_F(CustomCallTest, module->AddEntryComputation(b.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR3EqualArray3D( - Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); +} + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + b.AddInstruction( + HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues")); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + // Note, the expected result is transposed! This is because the input and + // output layouts of the custom call differ and the called function just + // blindly adds one to each element. 
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); +} + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { + // The argument and result of the computation are set to different layouts, + // but the custom call is layout constrained to a fixed operand and result + // layout, so the correct result should be produced. + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + const Shape& r2f32_dim0_major = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + b.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); } class CustomCallClientAPITest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 5f234f36a8543ad408fb3430b27844beb16a54b5..86fd1ceb1368feedb14088fa7045224440f6c4f9 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -24,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { @@ -36,7 +36,7 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 2db6503afab748d7b778e26b2f9350ac64c7778b..e0f23b0fa807ca27038afa2eec5f739508e3d5bd 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -42,7 +42,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) @@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles1 = result_status1.ConsumeValueOrDie(); auto handles2 = result_status2.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); handles1[0].reset(); handles1[1].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 
4.0, 6.0, 8.0}, literal); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { // the same as handle[3] and handle[1] should be the same as handle[2]. auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { // should not have been deallocated because of reference counting. 
global_data.reset(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 5873516442fa63de47360acaa353abb3a97fe881..6c0847a875798870b4362a99ac2ab65d99f9f3e6 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -68,16 +68,16 @@ 
XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), + LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = - 
LiteralUtil::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( - param.dot_lhs_row_major))); + Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = - client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie(); std::unique_ptr> dot_rhs_data = MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); - std::unique_ptr dot_rhs_lit = + Literal dot_rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = - client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie(); std::unique_ptr> addend_data; - std::unique_ptr addend_lit; + Literal addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { @@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() { addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); - addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie(); } XlaBuilder builder(TestName()); @@ -395,6 +394,10 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { ParametricDotTestWithoutLayoutAssignment() { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "layout-assignment"); + // Disable algebraic simplification because the pass may replace a dot + // instruction with a layout-changing multiplication instruction. 
+ execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); } }; @@ -405,31 +408,18 @@ std::vector CreateNoLayoutAssignmentDotTestParameters() { for (bool lhs_row_major : {true, false}) { for (bool rhs_row_major : {true, false}) { for (bool has_addend : {true, false}) { + // The addend needs to be row major to match the result of the dot. params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } if (n != 1) { params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } } } } @@ -477,14 +467,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -511,12 +501,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = 
client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -584,7 +574,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -592,7 +582,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -630,13 +620,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -668,7 +658,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, 
{7.0f, 8.0f}}}, {{{9.0f, 10.0f}, {11.0f, 12.0f}}, {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) @@ -676,7 +666,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) .ConsumeValueOrDie(); @@ -708,14 +698,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { auto lhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -778,15 +768,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -827,15 +817,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + 
LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7f6f203a1ba48e0053f799c58bbbeae87aef1f7f..7501c6d957e7afe99b8c530e5f0d575f818367da 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -114,23 +114,23 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, + void RunR1(absl::Span input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { // bfloat16 has explicit constructors, so it does not implicitly convert the // way built-in types do, which is why we can't take the parameter as an - // ArraySlice. We also can't convert it to a vector, because - // vector is special so that it cannot be an ArraySlice, which + // Span. We also can't convert it to a vector, because + // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. 
Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie(); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*LiteralUtil::CreateR0(input_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(input_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_value = - std::move(*LiteralUtil::CreateR0(update_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(update_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_value = - std::move(*LiteralUtil::CreateR0(expected_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(expected_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -385,22 +385,22 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, - tensorflow::gtl::ArraySlice update_values_int, + void RunR1(absl::Span input_values_int, + absl::Span update_values_int, const std::vector slice_starts, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR1(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { - std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << literal->ToString(); + Literal literal = LiteralUtil::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal.ToString(); } }; @@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = ConstantLiteral(&builder, *input_literal); + auto input = ConstantLiteral(&builder, input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); @@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) { auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), *start_indices_literal, buffer)); + stream.get(), start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 5116e60ca63ef5f94b25b15e6616086fb9e44bbb..b08ece0e63e9472f657b49b533511e9b192d3212 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc 
b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index bf1de02ba9dbd97db9ee31484402fe9b92385219..51b50d456e496c9c01c38fb8539bb3737de16937 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -38,29 +38,29 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. - input_literal->Set({i - begin}, 0.0f); + input_literal.Set({i - begin}, 0.0f); } else { - input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + input_literal.Set({i - begin}, tensorflow::bit_cast(i)); } } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal->Get({i}))); + expected_result.push_back(evaluate_op(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 4a835a8e219d4b64fa144e12e9b4cbc41f45946f..3be9657db40a7ea073baca32d8a20ccd6fa8a274 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -17,12 +17,12 @@ limitations under the License. 
#include #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -37,8 +37,8 @@ class FloorCeilTest : public ClientLibraryTestBase { }; // Runs a computation and comparison on expected vs f(input) - void TestR1F32(tensorflow::gtl::ArraySlice input, - tensorflow::gtl::ArraySlice expected, Function f) { + void TestR1F32(absl::Span input, + absl::Span expected, Function f) { LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << "}"; XlaBuilder builder(TestName()); auto c = ConstantR1(&builder, input); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 341124170a5f6768720032394c42205f9185920a..4d4b676a538947c8dd92a7e34db72e45766cae2c 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -23,6 +23,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "absl/memory/memory.h" +#include "absl/types/span.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -42,14 +43,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -using tensorflow::gtl::ArraySlice; - namespace xla { namespace { @@ -113,26 +111,26 @@ class FusionTest : public HloTestBase { hlos[0] = builder.AddInstruction(std::move(root_hlo)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction( - ArraySlice(hlos, 0, Arity + 1), + absl::Span(hlos).subspan(0, Arity + 1), HloInstruction::FusionKind::kLoop); auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4))); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } } private: template - T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice xs); + T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); }; template <> float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -157,7 +155,7 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, template <> bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kEq: return xs[0] == xs[1]; @@ -224,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.5}, {2.72}}), - 
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.5}, {2.72}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -250,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -285,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { // Every element of result should be y = x^2 = 4.0. for (int i = 0; i < rand_dim0_size; ++i) { for (int j = 0; j < dim1_size; ++j) { - EXPECT_EQ(4.0, result->Get({i, j})); + EXPECT_EQ(4.0, result.Get({i, j})); } } } @@ -310,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -325,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(5), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -340,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 
HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -355,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -370,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -385,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{7}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -400,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + 
ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -415,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -430,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -445,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -461,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 2, 1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -479,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, 
-1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-3, -2, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -497,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -515,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -537,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-2, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -554,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { @@ -572,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - 
EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -601,11 +599,11 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, - HloInstruction::FusionKind::kLoop); + HloInstruction::FusionKind::kInput); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -626,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(-15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -676,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -712,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), - 
*ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({8}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -766,8 +764,10 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } -// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. -XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { +// TODO(b/117156505): Remove this test when the bug is fixed and the CPU backend +// should not generate layout changing elementwise operations. +#ifdef XLA_TEST_BACKEND_CPU +XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) { const string hlo_text = R"( HloModule Cluster @@ -784,20 +784,19 @@ ENTRY main { } )"; - std::unique_ptr operand = - LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, - test_runner_.Execute(std::move(module), {operand.get()}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + test_runner_.Execute(std::move(module), {&operand}, + /*run_hlo_passes=*/false)); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - *result)); + LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + result)); } +#endif class FusionClientLibraryTest : public ClientLibraryTestBase {}; @@ -823,16 +822,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { // where overflow is OK. 
Array2D arr(32, 32); arr.FillUnique(); - std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({0, 1})); - std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({1, 0})); - XlaOp p0 = AddParam(*l1, &b); + XlaOp p0 = AddParam(l1, &b); XlaOp sum = p0; for (int i = 1; i < kNumParams; ++i) { - auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b); sum = sum + p0 * pN * pN; } @@ -881,19 +880,19 @@ void BM_ParallelFusion(int num_iters) { auto param0_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = - client->LiteralToShapedBuffer(*param0_literal, device_ordinal) + client->LiteralToShapedBuffer(param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = - client->LiteralToShapedBuffer(*param1_literal, device_ordinal) + client->LiteralToShapedBuffer(param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = - client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + client->LiteralToShapedBuffer(param2_literal, device_ordinal) .ConsumeValueOrDie(); // Build executable. 
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 205d417f0c60e35c71ae6c7ed0a3b099e769f552..daa89398a697af9149797d621c3bdca80a00aedd 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -34,8 +34,7 @@ class GatherOperationTest : public HloTestBase { RunTest(hlo_text, {operand, start_indices}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -59,10 +58,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -80,10 +79,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -101,11 +100,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + 
RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -123,11 +121,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -145,11 +143,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -167,13 +165,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -191,13 +188,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = 
LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -215,10 +211,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -236,11 +232,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -258,9 +253,9 @@ ENTRY main { slice_sizes={1, 0} } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -282,11 +277,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, 
OutOfBoundsUnsignedIndex) { @@ -308,11 +303,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -334,11 +329,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -360,11 +355,11 @@ ENTRY main { ROOT result = u32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -382,10 +377,10 @@ ENTRY main { slice_sizes={1,3,2} } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, 
&start_indices); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -403,9 +398,9 @@ ENTRY main { slice_sizes={1} } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -423,10 +418,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -447,10 +442,10 @@ ENTRY main { ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -471,11 +466,10 @@ ENTRY main { ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { 
@@ -496,11 +490,11 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -521,13 +515,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, @@ -549,13 +542,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -576,10 +568,10 @@ ENTRY main { ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + 
RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -600,11 +592,10 @@ ENTRY main { ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -641,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr operand_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); + client_->TransferToServer(LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); @@ -658,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::vector> result_data, client_->ExecuteParallel(computation_instances)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, - *result_literal); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 51450314b611b49c643fb6fd5b0c0d2e7205a2d2..1115e50fe3120b7dbd891f07dedcacefa5ecf3ea 100644 --- 
a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -126,9 +126,8 @@ INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = - std::function)>; +using BinaryBuildFuncTy = std::function)>; struct BinaryOpTestParam { std::function compute_func; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 93ea144438afa2d6f2f6c696f54d1ab1073081b8..7ab2ecda58666acd7e9b8587d200a902b75822f3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -44,7 +44,6 @@ namespace { using absl::optional; using absl::string_view; -using tensorflow::gtl::ArraySlice; constexpr char kInterpreter[] = "interpreter"; @@ -87,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace HloTestBase::HloTestBase(bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : test_runner_(test_platform), reference_runner_(reference_platform) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); } std::unique_ptr HloTestBase::CreateNewModule(const string& name) { @@ -121,6 +126,14 @@ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, return status_or; } +/* static */ +PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfig::DEFAULT); + return 
precision_config; +} + DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. @@ -129,24 +142,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { return debug_options; } -StatusOr> HloTestBase::Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { +StatusOr HloTestBase::Execute(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } -std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { +Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) .ValueOrDie(); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { +Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -169,7 +179,8 @@ StatusOr> HloTestBase::MakeReferenceModule( } StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); @@ -183,12 +194,13 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( TF_ASSIGN_OR_RETURN(auto reference, reference_runner_.Execute(std::move(reference_module), arguments, run_hlo_passes)); - return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, error); } ::testing::AssertionResult HloTestBase::RunAndCompare( - 
std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -201,7 +213,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -216,13 +229,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( ::testing::AssertionResult HloTestBase::RunAndCompare( std::unique_ptr module, const optional& error, const std::function& reference_preprocessor) { - const auto& fake_arguments = - MakeFakeArguments(module.get()).ConsumeValueOrDie(); + auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompare(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -236,7 +248,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -270,7 +282,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& 
literal) { return const_cast(&literal); }); return test_runner_ .Execute(std::move(module_or_status.ValueOrDie()), fake_argument_ptrs, /*run_hlo_passes=*/true) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 06bcc397417e0666c8c97f4286aba7d0b42a2d98..217428befa474448cf2dcbae2eb6cb5b0e61d44c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -80,20 +80,26 @@ class HloTestBase : public ::testing::Test { static StatusOr RunHloPass(HloPassInterface* hlo_pass, HloModule* module); + static PrecisionConfig DefaultPrecisionConfig(int operands); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. HloTestBase(bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. 
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); ~HloTestBase() override {} @@ -113,19 +119,16 @@ class HloTestBase : public ::testing::Test { } // Executes the given module and return the result as a Literal. - StatusOr> Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. - std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + Literal ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments); - std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + Literal ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments); // Executes the given hlo module on two backends and compares results. // @@ -140,7 +143,7 @@ class HloTestBase : public ::testing::Test { // modified. ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -149,7 +152,7 @@ class HloTestBase : public ::testing::Test { // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -261,7 +264,7 @@ class HloTestBase : public ::testing::Test { // error happens before the results are computed, returns the error status. 
StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const absl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 8f86c528d0f346b0264948d592660911880f96d1..8bd0a729b77f3ec14204952cb0062103c823883e 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -21,64 +21,68 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} - -HloVerifiedTestBase::~HloVerifiedTestBase() { - // We can't call the ASSERT or EXPECT test macros in destructors, so we - // perform HLO verification in TearDown, and use the CHECK here to ensure - // users don't accidentally override the verification. - CHECK(tear_down_called_) - << "TearDown was never called; subclasses of HloVerifiedTestBase that " - << "override TearDown must call the superclass TearDown."; -} - -void HloVerifiedTestBase::TearDown() { - EXPECT_FALSE(tear_down_called_) - << "TearDown called more than once; it should be called exactly once."; - tear_down_called_ = true; - if (module_) { - VerifyModule(module_.get()); +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. 
+ return Status::OK(); } - for (int i = 0; i < modules_.size(); ++i) { - VerifyModule(modules_.at(i).get()); - } - HloTestBase::TearDown(); + return verifier_.Run(this).status(); } -void HloVerifiedTestBase::VerifyModule(HloModule* module) { - xla::StatusOr mutated = verifier().Run(module); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; } } +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), + verifier_layout_sensitive_(layout_sensitive), + allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} + HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = HloTestBase::CreateNewModule(); + module_ = CreateNewVerifiedModule(TestName()); } return *module_; } HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(HloTestBase::CreateNewModule()); + modules_.emplace_back(CreateNewVerifiedModule(name)); return modules_.back().get(); } void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); - VerifyModule(module_.get()); + module_ = CreateNewVerifiedModule(TestName()); + TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); + module_->VerifyOrAddFailure("after parsing"); } + +StatusOr> +HloVerifiedTestBase::ParseAndReturnVerifiedModule( + 
absl::string_view hlo_text, const HloModuleConfig& config) { + auto module = CreateNewVerifiedModule(TestName()); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + +std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index cc6967feed47b74846814454d550b38a474f3a04..388a99bb36408665edbc20ade6c6a733d64db88d 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -20,53 +20,84 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { -// A base class for HLO tests that stores a default HloModule, and automatically -// performs verification on that module on tear-down. +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. 
+ void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + +// A base class for HLO tests that stores a default VerifiedHloModule. class HloVerifiedTestBase : public HloTestBase { protected: - explicit HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision); - ~HloVerifiedTestBase() override; + HloVerifiedTestBase(bool layout_sensitive = false, + bool allow_mixed_precision = false); // Constructs a default shape verifier. std::unique_ptr MakeShapeVerifier(); - // Performs verification on the default HloModule returned by module(). - // Automatically called by the testing framework for each test. - // - // REQUIRED: subclasses that override TearDown() must call this explicitly. - void TearDown() override; - // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule& module(); + + ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule* CreateNewModule(const string& name = TestName()); - private: - void VerifyModule(HloModule* module); + // Creates and returns a verified HLO module with the given name. 
+ std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + private: // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. // // Lazily populated. Access via module(). - std::unique_ptr module_; + std::unique_ptr module_; + // Populated by calls to CreateNewModule. - std::vector> modules_; + std::vector> modules_; - bool tear_down_called_ = false; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c0263e811f94c90a69a460525ffa0c65127ebb5 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// This class includes unit tests which are expected to fail because invalid HLO +// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to +// include the necessary gunit parts to test this test machinery (needs the +// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the +// disabled tests enabled and failures can be manually compared against +// expectations. +class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; + +XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { + // Test shouldn't fail if no module is created at all. +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { + // Use module() to lazily create an empty module, build it up, and verify no + // failures. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { + // Use module() to lazily create an empty module and build up an invalid + // module. 
+ HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); + + *hlo_module.entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { + // Call CreateNewModule and build up a valid module. + HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { + // Call CreateNewModule and build up a invalid module. 
+ HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); + + *module->entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndVerifyModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + ParseAndVerifyModule(hlo_string); + EXPECT_EQ(module().entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} + +RANDOM GARBAGE +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +// This test is expected to fail. See test class comment. 
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleBad + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[1234] add(x,y) +} +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 17ac95ae0198d98490b25f7f2edd32d1e0495803..310f3495922250d68aa463fcbb24ef0b04603d09 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -23,40 +23,95 @@ limitations under the License. namespace xla { namespace { -class IotaTest : public ClientLibraryTestBase { - public: - explicit IotaTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) {} - template - std::vector GetExpected(const int64 num_elements) { - std::vector result(num_elements); - std::iota(result.begin(), result.end(), 0); - return result; +template +std::vector GetR1Expected(const int64 num_elements) { + std::vector result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; +} + +class IotaR1Test + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(IotaR1Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + Iota(&builder, element_type, num_elements); + if (element_type == F32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), {}, + ErrorSpec{0.0001}); + } else if (element_type == U32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); + } else { + CHECK_EQ(element_type, S32); + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); } -}; - -XLA_TEST_F(IotaTest, SimpleR1) { - for (int 
num_elements = 1; num_elements < 10000001; num_elements *= 10) { - { - XlaBuilder builder(TestName() + "_f32"); - IotaGen(&builder, F32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), {}, - ErrorSpec{0.0001}); - } - { - XlaBuilder builder(TestName() + "_u32"); - IotaGen(&builder, U32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } - { - XlaBuilder builder(TestName() + "_s32"); - IotaGen(&builder, S32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } +} + +INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test, + ::testing::Combine(::testing::Values(F32, U32, S32), + ::testing::Range(/*start=*/10, + /*end=*/10001, + /*step=*/10))); + +class IotaR2Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR2Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); } } +INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1))); + +class IotaR3Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR3Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + 
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1, 2))); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index a4e3a998fc48c364b8a61169167039d1c1ed28de..554eb24d44168caa7d7252015e3d99f2d567df9b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -35,8 +35,7 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { int64 now_usec = tensorflow::Env::Default()->NowMicros(); string filename = tensorflow::io::JoinPath( tensorflow::testing::TmpDir(), - tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), - now_usec, name.c_str())); + absl::StrFormat("tempfile-%s-%x-%s", get_hostname(), now_usec, name)); TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, literal.ToProto())); LOG(ERROR) << "wrote to " << name << " file: " << filename; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 3dad91951e7322275cb0bf64e5e790c402d6cce9..43cca91f64b2c0fbfde5054a361cf0f95302c23d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -62,7 +62,7 @@ class LiteralTestUtil { static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); template - static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Equal(absl::Span expected, const LiteralSlice& actual); template static void ExpectR2Equal( @@ -102,7 +102,7 @@ class LiteralTestUtil { const ErrorSpec& error); template - static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Near(absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error); template @@ -155,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); + absl::Span expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2(expected), actual)); } template @@ -176,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3(expected), actual)); } template /* static */ void 
LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, + absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2(expected), actual, error)); } template @@ -223,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3(expected), actual, error)); } template @@ -232,28 +232,28 @@ template 
std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 4151bfae0332ffc706ba730d181c487eabab856f..b6f9b8156b51144e4f74d285b1e4111d098f13c2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,11 +31,11 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } 
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal lhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + Literal rhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(64), + LiteralUtil::CreateR0(42), }); - CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; + CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal"; }; ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal"); } @@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto two = LiteralUtil::CreateR0(2); auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); - CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; + CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four"; }; tensorflow::Env* env = tensorflow::Env::Default(); @@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal->ToString()); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal->ToString()); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { 
- EXPECT_EQ("true", literal->ToString()); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -103,8 +103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto expected = LiteralUtil::CreateR1({1, 2, 3}); auto actual = LiteralUtil::CreateR1({4, 5, 6}); - ::testing::AssertionResult result = - LiteralTestUtil::Equal(*expected, *actual); + ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); EXPECT_THAT(result.message(), @@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { @@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) { {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { @@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); - EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc 
b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 8d658695576035cdc34a213847460dd80de5f67e..c622b295094e53e63d0ed692d428bc97724c787c 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -93,15 +93,16 @@ class LLVMCompilerTest : public ::testing::Test { std::unique_ptr hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); - std::vector> modules; - modules.push_back(hlo_module->Clone()); - modules.push_back(std::move(hlo_module)); + auto module_group = absl::make_unique("test_module_group"); + module_group->push_back(hlo_module->Clone()); + module_group->push_back(std::move(hlo_module)); std::vector> executors; executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); - EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors), + EXPECT_IS_OK(compiler->Compile(std::move(module_group), + std::move(executors), /*device_allocator=*/nullptr)); } @@ -150,12 +151,12 @@ TEST_F(GpuCompilerTest, HooksTest) { TestCompilerHooks(&compiler); } -TEST_F(CpuCompilerTest, MultiModuleCompilation) { +TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) { cpu::CpuCompiler compiler; TestMultiModuleCompilation(&compiler); } -TEST_F(GpuCompilerTest, MultModuleCompilation) { +TEST_F(GpuCompilerTest, NVPTXMultiModuleCompilation) { gpu::NVPTXCompiler compiler; TestMultiModuleCompilation(&compiler); } diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 237a4a361e386e24c2897c42602eb60ca7234731..dbdd20daf0c3a54ed7b6e2a9d3fb73274d77474a 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); 
auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { DefaultExecutableBuildOptions(), options); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_); // At least one allocation should have been performed when executing the // computation. @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. 
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 1a823cf189b310c62c735419936544ea99fcfbaf..a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(123.f, ShapedBufferToLiteral(result), error_spec_); } @@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(165.f, ShapedBufferToLiteral(result), error_spec_); } @@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR1Near({}, ShapedBufferToLiteral(result), error_spec_); } @@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); 
LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { @@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. 
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. ScopedShapedBuffer result_param_swap = ExecuteLocallyOrDie(computation, {&y_array, &x_array}); - LiteralTestUtil::ExpectR2Near( - {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_param_swap), error_spec_); + LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, + ShapedBufferToLiteral(result_param_swap), + error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. 
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. @@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_rowmaj), + ShapedBufferToLiteral(result_rowmaj), error_spec_); } @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - 
LiteralSlice(*result_literal, {2})); + LiteralSlice(result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); + LiteralSlice(result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); + LiteralSlice(result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralSlice(result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ 
-268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), - LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}); + auto y_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}})}); - auto x_buffer = LiteralToShapedBuffer(*x_literal); - auto y_buffer = LiteralToShapedBuffer(*y_literal); + auto x_buffer = LiteralToShapedBuffer(x_literal); + auto y_buffer = LiteralToShapedBuffer(y_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); @@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + 
Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) - .get(), - LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}), + LiteralUtil::CreateR1({222.0, -2.0, 10.0})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - 
{LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); + Literal result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); + LiteralSlice(result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - LiteralSlice(*result_0_literal, {1})); + LiteralSlice(result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); + Literal result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - LiteralSlice(*result_1_literal, {0})); + LiteralSlice(result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - LiteralSlice(*result_1_literal, {1})); + LiteralSlice(result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // Feed in a tuple where each two-element vector element is {tuple_index, // -tuple_index}. 
- std::vector> arg_elements; + std::vector arg_elements; for (int i = 0; i < kElementCount; ++i) { arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } - std::unique_ptr arg_literal = - LiteralUtil::MakeTupleOwned(std::move(arg_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements)); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_); } } @@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. 
- std::vector> outer_tuple_elements; + std::vector outer_tuple_elements; for (int i = 0; i < kFanout; ++i) { - std::vector> inner_tuple_elements; + std::vector inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } @@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { } auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { - LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), - error_spec_); + LiteralTestUtil::ExpectR0Near(i + j + i * kFanout + j, + LiteralSlice(result_literal, {i, j}), + error_spec_); } } } @@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. 
- std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); + Literal arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { - std::vector> arg_vector; + std::vector arg_vector; arg_vector.push_back(std::move(arg_literal)); arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal(165.0, - LiteralSlice(*result_literal, index)); + LiteralSlice(result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { 
DefaultExecutableRunOptions().set_device_ordinal(d)); EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + ShapedBufferToLiteral(result)); } } } @@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. EXPECT_EQ(d, result.device_ordinal()); - LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + LiteralTestUtil::ExpectR0Equal(42.0f, ShapedBufferToLiteral(result)); } } @@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); + Literal tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - LiteralSlice(*tuple_literal, {0})); + LiteralSlice(tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - LiteralSlice(*tuple_literal, {1})); + LiteralSlice(tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, 
ShapeBufferToLiteralConversion) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; // Array shapes. - test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); - test_to_device_and_back(*LiteralUtil::CreateR0(true)); - test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(LiteralUtil::CreateR0(true)); + test_to_device_and_back(LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*LiteralUtil::MakeTuple({})); + test_to_device_and_back(LiteralUtil::MakeTuple({})); // Non-nested tuples. - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(12223.0)})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)})); // Nested tuple. 
- test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()}) - .get(), - LiteralUtil::CreateR0(false).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)}), + LiteralUtil::CreateR0(false)})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456789000LL).get()})); + LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456789000LL)})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); Add(in, constant); - std::unique_ptr result; + Literal result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); 
ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. thread.reset(); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { @@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + TF_ASSERT_OK_AND_ASSIGN(Literal result, local_client_->TransferFromOutfeedLocal( shape, local_client_->default_device_ordinal())); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } // Benchmark that measures the overhead of the LocalClient API when running a @@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) { auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, - buffer)); + ASSERT_IS_OK( + transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer)); const int kWarmups = 2; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 948b60061e2f47c73c7c7a2d6cbc65baf1b4411c..f90ef22d2d549f451f8af231aea834e9f097b12a 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( 
.ConsumeValueOrDie(); } -std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( +Literal LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { return local_client_->ShapedBufferToLiteral(shaped_buffer) .ConsumeValueOrDie(); @@ -156,7 +156,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -164,7 +164,7 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { return ExecuteLocally(computation, arguments, build_options, run_options) @@ -173,14 +173,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index b4477e9a6b23363ee3a1380f9f98f4b8226f6920..4027c7b124f8ac6e4b94600871ac32376a3e6467 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ 
b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -86,26 +86,25 @@ class LocalClientTestBase : public ::testing::Test { // Construct and return a literal containing the array represented by // shaped_buffer. - std::unique_ptr ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Execute the given computation on the local client. With and without // options. 
StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 0732e195d44d738b264361e43d38259c26a4116e..4d327a6fe9c45174a0666fd573a081e0cfe450d2 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + Literal param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, @@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, @@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, @@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -224,12 +224,12 @@ 
TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( @@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( @@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); @@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). 
XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( + Literal param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( + Literal param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1}); @@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = 
Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1, 2}); @@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - std::unique_ptr param2_literal = + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); + Literal param2_literal = LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = - client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); + client_->TransferToServer(param2_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); - auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2"); Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( @@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) { Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 
4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); @@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, @@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // 
note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( @@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + Literal param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 7956a034f8806bd9f3f50dd4f8e7c2e3405acc0d..3f278115e078877de1683574370df7790c2801fd 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ 
b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -63,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { }); Exp(data); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { }); Map(&builder, {data}, add_half, {0, 1}); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { }); Max(lhs, rhs); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + 
this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -136,8 +136,7 @@ class TestLinspaceMaxParametric MakeLinspaceArray2D(from, to, rows, cols); auto arhs = absl::make_unique>(rows, cols, static_cast(1.0f)); - XlaBuilder builder( - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + XlaBuilder builder(absl::StrFormat("max_%dx%d_linspace", rows, cols)); auto lhs = ConstantR2FromArray2D(&builder, *alhs); auto rhs = ConstantR2FromArray2D(&builder, *arhs); Max(lhs, rhs); @@ -201,14 +200,12 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 16b77e965d11fa136529e70796d11c520962ef28..56aaeb0e6878737e6c689e8065d8f1e1871b3472 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -47,8 +47,6 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; - class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } @@ -91,13 +89,13 @@ class MultiOutputFusionTest : public HloTestBase { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { - auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( - ArraySlice({sub, add2}, 0, 2))); + auto tuple = + computation->AddInstruction(HloInstruction::CreateTuple({sub, add2})); auto gte0 = computation->AddInstruction( HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0)); auto gte1 = computation->AddInstruction( @@ -116,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); + Literal literal_r0 = 
LiteralUtil::CreateR0(-9.0f); auto actual = - ExecuteAndTransfer(std::move(hlo_module), - {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1}); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -155,12 +153,12 @@ class MultiOutputFusionTest : public HloTestBase { dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, - dot_dnums)); + dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { - auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( - ArraySlice({sub_U8, add}, 0, 2))); + auto tuple = computation->AddInstruction( + HloInstruction::CreateTuple({sub_U8, add})); auto gte0 = computation->AddInstruction( HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0)); @@ -180,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = - std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); + Literal expect = LiteralUtil::CreateR1({size * 1.5f * 3.5f}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } }; @@ -220,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { LiteralUtil::CreateR0(1.0)), LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), LiteralUtil::CreateR0(4))); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); 
EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -249,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -282,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); } const char* const kScalarOps = R"( @@ -326,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -358,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, 
{7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -391,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), - LiteralUtil::CreateR1({36, 64}), - LiteralUtil::CreateR1({66, 138})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -424,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -458,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - 
*LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -494,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR1({14, 22}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR3( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -532,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest, LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); auto init2 = LiteralUtil::CreateR0(6); - std::unique_ptr result = ExecuteNoHloPasses( - std::move(module), {param.get(), init1.get(), init2.get()}); + Literal result = + ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{167, 172}, {176, 180}}), LiteralUtil::CreateR2({{6, 6}, {6, 8}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -567,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest, auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + 
LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}}), LiteralUtil::CreateR3( @@ -578,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}})), - *result)); + result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c530591c6e5fe75658dd507d794f8b6a64442871 --- /dev/null +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +StatusOr BuildComputation() { + XlaBuilder b("computation"); + Shape scalar_s32 = ShapeUtil::MakeShape(S32, {}); + XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32); + return b.Build( + OutfeedWithToken(GetTupleElement(infeed, 0) + + ConstantLiteral(&b, LiteralUtil::CreateR0(1)), + GetTupleElement(infeed, 1), scalar_s32, "")); +} + +void CompileAndExecute( + LocalExecutable* executable, int device_ordinal, LocalClient* client, + absl::Mutex* results_mutex, + std::vector>>* results) { + xla::ExecutableRunOptions execute_options; + execute_options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + execute_options.set_device_ordinal(device_ordinal); + execute_options.set_allocator( + xla::ClientLibrary::GetXlaService(client->platform()) + ->backend() + .memory_allocator()); + StatusOr result = executable->Run({}, execute_options); + { + absl::MutexLock lock(results_mutex); + results->emplace_back(device_ordinal, std::move(result)); + } +} + +void TestWithDeviceCount(const int device_count) { + // Run `device_count` copies of the XLA program built by BuildComputation. 
+ TF_ASSERT_OK_AND_ASSIGN( + se::Platform* const platform, + perftools::gputools::MultiPlatformManager::PlatformWithName("Host")); + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + TF_ASSERT_OK_AND_ASSIGN( + LocalClient* const client, + xla::ClientLibrary::GetOrCreateLocalClient(client_options)); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{})); + std::vector threads; + absl::Mutex results_mutex; + std::vector>> results; + tensorflow::Env* env = tensorflow::Env::Default(); + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + tensorflow::Thread* t = env->StartThread( + tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal), + [&executable, device_ordinal, client, &results_mutex, &results] { + CompileAndExecute(executable.get(), device_ordinal, client, + &results_mutex, &results); + }); + threads.push_back(t); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(client->TransferToInfeedLocal( + LiteralUtil::CreateR0(device_ordinal * 100), device_ordinal)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK_AND_ASSIGN(Literal outfeed, + client->TransferFromOutfeedLocal( + ShapeUtil::MakeShape(S32, {}), device_ordinal)); + EXPECT_EQ(outfeed, LiteralUtil::CreateR0(device_ordinal * 100 + 1)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + delete threads[device_ordinal]; + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(results[device_ordinal].second.status()); + } +} + +// NB! 
This test requires --xla_force_host_platform_device_count=4 + +TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); } + +TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); } + +TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); } + +TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); } +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc index 0a0426adcbc1b5b89be0841fa2c4204e2b65abf4..f2460822a61fef11e5c35c731fa6eca5df72b60b 100644 --- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { GetTupleElement(result_tuple, 0); TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { VLOG(1) << "Transferring trip count to computation"; // Transfer number of iterations to Infeed. TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(1))); // Pick up value from outfeed { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 1); + EXPECT_EQ(r.Get({}), 1); } VLOG(1) << "Writing data to infeed"; // Transfer some stuff to Infeed for use inside of loop. 
TF_ASSERT_OK(local_client_->TransferToInfeed( - *LiteralUtil::CreateR1({10, 20}))); + LiteralUtil::CreateR1({10, 20}))); // Pick up value from outfeed { VLOG(1) << "Reading from body outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&xfeed_shape)); - EXPECT_EQ(r->Get({0}), 11); - EXPECT_EQ(r->Get({1}), 21); + EXPECT_EQ(r.Get({0}), 11); + EXPECT_EQ(r.Get({1}), 21); } { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 0); + EXPECT_EQ(r.Get({}), 0); } // Joins the thread thread.reset(); - EXPECT_EQ(comp_result->Get({}), 0); + EXPECT_EQ(comp_result.Get({}), 0); } XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { @@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { })); TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(true))); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&result_shape)); - EXPECT_EQ(r->Get({}), true); + EXPECT_EQ(r.Get({}), true); // Join the thread thread.reset(); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index cbeddffacfa4a0fc560e8b9f9a8d7bd23ff32e55..6e98167739c234fae335bcc9e024423e7fc87197 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, 
Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*LiteralUtil::CreateR0(1.5), &b), + AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); @@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(2, 3, 3, 2); @@ -168,7 +168,7 @@ 
TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); Pad(AddParam(input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(8, 5, 1, 1); @@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); 
padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), - padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, 
AddParam(*LiteralUtil::CreateR0(0.0f), &b), - padding_config); + Pad(reduce, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index f6c762e7a4bee91a26c4c2e033c3717fef6d91d0..dcb4c11c3ccab5992e1ea4fadf02fda8ff77e7ea 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + Literal param0_literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); @@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); @@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, 
ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + Literal param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), @@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); @@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); @@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), 
"param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); // Use both parameters // @@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + Literal literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); @@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + 
Parameter(&builder, 1, literal1.shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + Literal literal1 = LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1.shape(), "param2"); // This add is unused. 
Add(param1, param2); @@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + Literal literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( - client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + client_->TransferToServer(literal).ConsumeValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR0(i); + Literal literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); sum_handle = Add(sum_handle, param); } @@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; 
i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); - parameter_shapes.push_back(literal->shape()); + parameter_shapes.push_back(literal.shape()); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. 
- std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); + Literal bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( - std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + std::move(client_->TransferToServer(bool_literal)).ValueOrDie()); XlaOp bool_param = - Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param"); params.push_back(bool_param); - parameter_shapes.push_back(bool_literal->shape()); + parameter_shapes.push_back(bool_literal.shape()); auto init = Tuple(&builder, params); @@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest, param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({i, i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR1({4, 5, 6}), })) .ConsumeValueOrDie(); @@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. 
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + Literal literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); - const Shape original = literal->shape(); + const Shape original = literal.shape(); { // Reverse the layout present in original, and make that the layout of the // literal. 
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape_do_not_use()->mutable_layout() = + *literal.mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, literal->Get({0, 1})); + ASSERT_EQ(2, literal.Get({0, 1})); } // Use the original shape in building the computation. XlaBuilder builder(TestName()); @@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); // Check that we got the off-diagonal value that we expected. Array2D expected(1, 1); expected(0, 0) = 2; diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 2fc7f816b56db6f57ca835d1847476b6d622ce5e..58539e6b061b0cec1cc660b52e78894e5deeea56 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -31,7 +31,7 @@ class PredTest : public ClientLibraryTestBase { protected: void TestCompare(bool lhs, bool rhs, bool expected, std::function)> + absl::Span)> op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 326e13b3867f2f804e882e00e35850d0189ad8d7..8f2c26f0eea9c7a3b33cd77e5977924c1659535a 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -37,9 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, - tensorflow::gtl::ArraySlice dims, - int64 seed = 42); + Literal UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. `expected_count` is the number of times each @@ -50,8 +48,8 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest( - T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { +Literal PrngTest::UniformTest(T a, T b, absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -60,8 +58,8 @@ std::unique_ptr PrngTest::UniformTest( SetSeed(seed); auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions())); + actual.EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { constexpr int64 count = 100; for (int64 seed = 0; seed < count; ++seed) { auto result = 
UniformTest(low, high, {}, /*seed=*/seed); - result->Literal::EachCell( - [&](tensorflow::gtl::ArraySlice, bfloat16 value) { - int64 index = static_cast((value - low) / interval); - counts[index]++; - }); + result.EachCell([&](absl::Span, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); } // Each bucket should have similar amount of counts. That is, not more than // 10% of total counts. This mostly tests that we don't fall into a 1:2:2 @@ -149,8 +146,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - actual->EachCell([&counts](tensorflow::gtl::ArraySlice, - int32 value) { ++counts[value]; }); + actual.EachCell( + [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); @@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { }; XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, - client_->TransferToServer(*param0_literal)); + client_->TransferToServer(param0_literal)); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto fn = build_sum_rng(builder); Map(&builder, {param0}, fn, {0}); @@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), - ShapeUtil::ElementsIn(param0_literal->shape())); - for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { - EXPECT_GE(actual->data()[i], param0_literal->data()[i]); - EXPECT_LT(actual->data()[i], - param0_literal->data()[i] + 1.0f); + 
EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()), + ShapeUtil::ElementsIn(param0_literal.shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) { + EXPECT_GE(actual.data()[i], param0_literal.data()[i]); + EXPECT_LT(actual.data()[i], param0_literal.data()[i] + 1.0f); } } @@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); - std::unique_ptr result1; + Literal result1; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options1)); } - std::unique_ptr result2; - std::unique_ptr result3; + Literal result2; + Literal result3; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options1)); } - std::unique_ptr result4; - std::unique_ptr result5; - std::unique_ptr result6; + Literal result4; + Literal result5; + Literal result6; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc 
b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9af9ea4a2229bb6ca7c3561350f11837f5072a2c..c9096fb29b2019796c42b69de80c63b5fc7c5c3a 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = LiteralUtil::CreateR4( + Literal reduce_input = LiteralUtil::CreateR4( {{ /*i0=0*/ {/*i1=0*/ {-0.246092796, -0.179497838, -0.161181688}, diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 0916a07f4fa99af6cf25441fa8558a558bfa032f..26e2bfde5cdc19657640f24f31bc008d09ad7106 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR1({input_values}); + Literal a_literal = LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); ReducePrecision(a, exponent_bits, mantissa_bits); @@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = 
Parameter(&builder, 0, a_literal.shape(), "a"); // Abs doesn't affect resolution. auto abs = Abs(a); @@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. 
auto abs = Abs(a); @@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index b93d838349d90d34d1792529456cdbd58d40b8fd..83997cdac21c437d460dabdbdfdb31100b1359af 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,7 +32,9 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -52,7 +54,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -80,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase { }, 4); // clang-format on CHECK(ShapeUtil::Equal( - literal_3d_->shape(), + literal_3d_.shape(), ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) - << literal_3d_->shape().ShortDebugString(); + << literal_3d_.shape().ShortDebugString(); } // Runs an R1 => R0 reduction test with the given number of elements. 
@@ -101,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase { input_data[i] *= -1; } } - std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (float item : input_data) { @@ -114,8 +114,7 @@ class ReduceTest : public ClientLibraryTestBase { ErrorSpec(0.001)); } - void RunR1ToR0PredTest(bool and_reduce, - tensorflow::gtl::ArraySlice input_data) { + void RunR1ToR0PredTest(bool and_reduce, absl::Span input_data) { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); @@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + Literal input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); bool expected = and_reduce; for (bool item : input_data) { @@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::array expected; for (int64 colno = 0; 
colno < cols; ++colno) { @@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (int64 rowno = 0; rowno < rows; ++rowno) { @@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -260,8 +256,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments, ErrorSpec(0.01, 1e-4)); } @@ -270,8 +266,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + 
absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments); } @@ -295,15 +291,14 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to - // ArraySlice. + // Span. std::unique_ptr expected(new NativeT[cols]); for (int64 colno = 0; colno < cols; ++colno) { NativeT column_result = initial_value; @@ -315,7 +310,7 @@ class ReduceTest : public ClientLibraryTestBase { } ComputeAndCompareGeneric( - &builder, tensorflow::gtl::ArraySlice(expected.get(), cols), + &builder, absl::Span(expected.get(), cols), {input_global_data.get()}); } @@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase { reference_reduction_function_for_uints, unsigned_int_identity); } - std::unique_ptr literal_2d_; - std::unique_ptr literal_3d_; + Literal literal_2d_; + Literal literal_3d_; uint32 seed_ = 0xdeadbeef; }; @@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - MakeFakeLiteral(input_shape)); + TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 major = 0; major < 2; ++major) { @@ -557,12 +548,11 @@ struct 
BoundsLayout { }; void PrintTo(const BoundsLayout& spec, std::ostream* os) { - *os << tensorflow::strings::Printf( - "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(), - spec.bounds.size() - spec.reduce_dims.size(), - absl::StrJoin(spec.bounds, "x").c_str(), - absl::StrJoin(spec.layout, "").c_str(), - absl::StrJoin(spec.reduce_dims, "").c_str()); + *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(), + spec.bounds.size() - spec.reduce_dims.size(), + absl::StrJoin(spec.bounds, "x"), + absl::StrJoin(spec.layout, ""), + absl::StrJoin(spec.reduce_dims, "")); } // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. @@ -596,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( @@ -611,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; @@ -628,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::max()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -640,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1}); 
ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -658,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -668,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -678,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); @@ -688,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -698,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder 
builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); @@ -708,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -723,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -740,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); @@ -825,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); + input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); Reduce(input_activations, ConstantR0(&builder, 0.0f), add, 
GetParam().reduce_dims); @@ -867,21 +857,17 @@ INSTANTIATE_TEST_CASE_P( BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}}, BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}})); -// TODO(b/64093391) Disabled on GPU due to an assertion failure when running -// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on -// 2017-07-26. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { +XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = - LiteralUtil::CreateR1({1.0f, 4.0f}); + Literal b_literal = LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b = Parameter(&builder, 0, b_literal.shape(), "b"); Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); @@ -908,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest { std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, - max_fn, {0}); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn, + {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -956,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = - LiteralUtil::CreateR1(operand); + Literal input_literal = LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = - 
client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Literal input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = - client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + client_->TransferToServer(input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( &builder, expected, {input_global_data.get(), input_global_data2.get()}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 60167619a4eb89b3275cc728300c41419ce80c60..22fe4a2670e2e0e1fedc45036a1ceec19f44e42e 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -38,7 +39,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -57,7 +57,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { if (use_bfloat16()) { - return ErrorSpec(1e-1, 5e-2); + return ErrorSpec(2e-1, 6e-2); } else { return ErrorSpec(1e-3, 1e-3); } @@ -70,10 +70,10 @@ class ReduceWindowTest : public ::testing::WithParamInterface, ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } void ReduceWindowAdd(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), @@ -81,8 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMax(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); @@ -92,8 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMin(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); @@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, 
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); + LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({100, 1}), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, 
/*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), {}, ErrorSpec(0.00001)); } @@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, 
AmongMajor2DimsMediumSize) { @@ -252,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. 
@@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected), {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -360,9 +359,9 @@ 
XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = absl::make_unique(shape); - arg_literal->PopulateWithValue(1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + Literal arg_literal(shape); + arg_literal.PopulateWithValue(1.0f); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); @@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = absl::make_unique(result_shape); - expected->PopulateWithValue(27.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + Literal expected(result_shape); + expected.PopulateWithValue(27.0f); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); - std::unique_ptr 
input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - 
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), 
&builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { @@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -594,7 +588,7 @@ string R4ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. 
std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -613,12 +607,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillIota(1); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + input.FillRandom(0.1f, 0.1f); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(4); @@ -627,9 +620,16 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); - auto computation = param.reducer == kAdd + auto reducer = param.reducer; + if (use_bfloat16() && Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + + auto computation = reducer == kAdd ? 
CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); ReduceWindowWithGeneralPadding( @@ -638,10 +638,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, /*padding=*/padding); - CHECK(param.reducer == kAdd || param.reducer == kMax); - auto reduce_func = param.reducer == kAdd + CHECK(reducer == kAdd || reducer == kMax); + auto reduce_func = reducer == kAdd ? +[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; std::unique_ptr> expected = @@ -652,12 +654,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - std::unique_ptr expected_literal = - LiteralUtil::CreateFromArray(*expected); + Literal expected_literal = LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( - input_literal->shape().element_type(), - AsInt64Slice(expected_literal->shape().dimensions()), param.layout); - ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + input_literal.shape().element_type(), + AsInt64Slice(expected_literal.shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()}, DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -809,6 +810,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{1, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + /*window_bounds=*/{1, 64, 64, 1}, + /*strides=*/{1, 64, 64, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 0, 2, 1}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64}, + /*window_bounds=*/{112, 112, 1, 8}, + /*strides=*/{112, 112, 1, 8}, + 
/*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -930,6 +947,27 @@ struct R3ReduceWindowTestData { {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, }; string R3ReduceWindowTestDataToString( @@ -944,7 +982,7 @@ string R3ReduceWindowTestDataToString( param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -956,35 +994,41 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } }; -TEST_P(R3ReduceWindowTest, Add) { +TEST_P(R3ReduceWindowTest, DoIt) { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], - param.base_bounds[2], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + param.base_bounds[2]); + input.FillRandom(0.1f, 0.1f); + Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + auto reducer = param.reducer; + if (use_bfloat16()) { + input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); + if (Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + } - XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); + + auto computation = reducer == kAdd + ? 
CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - auto expected = ReferenceUtil::ReduceWindow3DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); - - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1079,7 +1123,7 @@ string R2ReduceWindowTestDataToString( param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1093,16 +1137,14 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1112,13 +1154,16 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? 
CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1128,7 +1173,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1263,11 +1308,19 @@ struct R1ReduceWindowTestData { /*pad_high=*/{0}, /*reducer=*/Reducer::kAdd}, + // The pattern generated by inclusive scan (cumsum/cumprod). {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, /*strides=*/{1}, /*pad_low=*/{4095}, /*pad_high=*/{0}, /*reducer=*/Reducer::kMax}, + + // The pattern generated by exclusive scan (cumsum/cumprod). + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4096}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, }; string R1ReduceWindowTestDataToString( @@ -1282,7 +1335,7 @@ string R1ReduceWindowTestDataToString( "__pad_high_", absl::StrJoin(param.pad_high, "x"), "__reducer_", param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1302,11 +1355,11 @@ TEST_P(R1ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); - std::unique_ptr input_literal = - LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + Literal input_literal = + LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + auto input_arg = + CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; @@ -1315,26 +1368,29 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? 
+[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; auto expected = ReferenceUtil::ReduceWindow1DGeneric( - /*operand=*/tensorflow::gtl::ArraySlice(input_vector), + /*operand=*/absl::Span(input_vector), /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index d8914513819415368a628eab1f482f9644dd46b1..5cf87e565bf493167f5173588e7afa3b96282488 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. - LiteralTestUtil::ExpectR0Equal(4, *literal); + LiteralTestUtil::ExpectR0Equal(4, literal); } XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { @@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{x_data.get(), y_data.get()}, @@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { .ConsumeValueOrDie(); // Expect 5. 
- LiteralTestUtil::ExpectR0Equal(5, *literal); + LiteralTestUtil::ExpectR0Equal(5, literal); } TEST_F(ReplayTest, MapPlusTwoOverR1) { @@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. - LiteralTestUtil::ExpectR1Equal({3, 4, 5}, *literal); + LiteralTestUtil::ExpectR1Equal({3, 4, 5}, literal); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 368f5583c9ce3773e57b858ff7606f679346529a..ae24eb5eb4822a2057e34a1aec8b7d64604d8984 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 382d1b1ae741285dcd1f7761edb82a5c333887af..dedc95b5ae8315185a35f786af42aad53bd7ad96 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, 
*input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(1.0f); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = LiteralUtil::CreateR1({-1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp 
parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = 
CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, 
/*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, 
input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = 
CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); 
XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); auto expected_literal = LiteralUtil::CreateFromArray(expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, 
zero_error_spec_); } @@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); @@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { {35, 36, 37}, {40, 41, 42}, {45, 46, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); @@ -453,7 +453,7 @@ 
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { {45, 16, 26}, {36, 46, 17}, {27, 37, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp 
parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); auto expected_literal = LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); - ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&b, expected_literal, {input.get()}, zero_error_spec_); } } @@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}, {}); EXPECT_THAT( @@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), @@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -619,27 +619,26 @@ 
XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, {1, 0}); - std::unique_ptr actual = + Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); - std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralUtil::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(expected); } - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {{204, 205, 206, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {{206, 7, 107, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -689,20 +688,17 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = 
LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -711,20 +707,17 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -734,25 +727,23 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = 
LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); - input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { + input.Each([&](absl::Span indices, float* cell) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -762,14 +753,13 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); input_array.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -778,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? 
BF16 : F32, {7, 2, 3, 5}, {2, 3, 0, 1}); - std::unique_ptr output_literal = + Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, &execution_options) @@ -787,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); - EXPECT_EQ(expected->data(), output_literal->data()); + auto expected = LiteralUtil::ConvertF32ToBF16(input_literal); + EXPECT_EQ(expected.data(), output_literal.data()); } else { - EXPECT_EQ(input_literal->data(), output_literal->data()); + EXPECT_EQ(input_literal.data(), output_literal.data()); } } @@ -801,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); + ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()}); } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { @@ -816,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -833,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); + ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()}); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { @@ -842,27 +832,25 @@ 
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::vector bounds = {2, 2, 2, 2}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { @@ -871,27 +859,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::vector bounds = {1, 1, 250, 300}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { @@ -900,27 +886,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::vector bounds = {5, 5, 1, 10}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { @@ -930,27 +914,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::vector bounds = {5, 5, 10, 1}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { @@ -959,27 +941,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::vector bounds = {3, 3, 1, 3}; std::vector new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({0, 1, 2, 3})); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) - ->Relayout(input_literal->shape().layout()); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal) + .Relayout(input_literal.shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 60084f143de5567359893a56a51719f87a720ce5..4e55b0d7ac4453d074500f3a7fda96cb5ab52c56 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -38,14 +39,14 @@ static std::array use_bfloat16_params{false}; #endif struct ReverseSpec { - tensorflow::gtl::ArraySlice input_dims; - tensorflow::gtl::ArraySlice reversal; + absl::Span input_dims; + absl::Span reversal; bool use_bfloat16; string ToTestCaseName() const { - return tensorflow::strings::Printf( - "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x").c_str(), - absl::StrJoin(reversal, "x").c_str(), use_bfloat16 ? "bf16" : "f32"); + return absl::StrFormat( + "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"), + absl::StrJoin(reversal, "x"), use_bfloat16 ? 
"bf16" : "f32"); } }; @@ -82,26 +83,25 @@ TEST_P(FloatReverseTest, Reverses) { ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); auto r1_literal = LiteralUtil::CreateR1(input_vector); - auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto a = AddParam(*input_literal, &builder); + auto a = AddParam(input_literal, &builder); Rev(a, spec.reversal); - std::unique_ptr expected = input_literal->CloneToUnique(); + Literal expected = input_literal.Clone(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float) { - for (int64 i = 0; i < indices.size(); ++i) { - output_indices[i] = indices[i]; - } - float value = input_literal->Get(indices); - for (int64 dim : spec.reversal) { - output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; - } - expected->Set(output_indices, value); - }); - ComputeAndCompareLiteral(&builder, *expected, {}); + expected.EachCell([&](absl::Span indices, float) { + for (int64 i = 0; i < indices.size(); ++i) { + output_indices[i] = indices[i]; + } + float value = input_literal.Get(indices); + for (int64 dim : spec.reversal) { + output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; + } + expected.Set(output_indices, value); + }); + ComputeAndCompareLiteral(&builder, expected, {}); } INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index a620fe19085d98c8b6642b25b159d6c2308bdae2..091a5d2cacce6ac5bf986776e5ec96612d08cc75 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -15,6 +15,7 @@ limitations 
under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -38,7 +38,7 @@ namespace { class RoundTripPackedLiteralTest : public ClientLibraryTestBase { protected: // Sends the literal to the server and retrieves it back. - std::unique_ptr RoundTripToServer(const Literal& original) { + Literal RoundTripToServer(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); return client_->Transfer(*data).ConsumeValueOrDie(); @@ -47,8 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase { TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { string data(sizeof(float) * 2, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 2); + absl::Span floats(tensorflow::bit_cast(data.data()), 2); floats[0] = 42.0; floats[1] = 24.0; @@ -60,18 +59,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, actual->Get({0})); - EXPECT_EQ(24.0, actual->Get({1})); + EXPECT_EQ(42.0, actual.Get({0})); + EXPECT_EQ(24.0, actual.Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { string data(sizeof(float) * 4, 0); - 
tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With x as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=0,x=1 @@ -89,24 +87,22 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({0, 1})); - EXPECT_EQ(64.0f, actual->Get({1, 0})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({0, 1})); + EXPECT_EQ(64.0f, actual.Get({1, 0})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With y as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=1,x=0 @@ -124,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual 
= reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({1, 0})); - EXPECT_EQ(64.0f, actual->Get({0, 1})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({1, 0})); + EXPECT_EQ(64.0f, actual.Get({0, 1})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index a8193c2eac05ba4f0df339909f3e82a28ac35253..cd5a531603b0cb6e0f48f4dcd49891cbd5428602 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase { void RoundTripTest(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); + Literal result = client_->Transfer(*data).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralTestUtil::Equal(original, result)); } }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - 
RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - 
RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(1.0), LiteralUtil::CreateR1({2, 3})})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index cf2d453f43cda88ca05ab211a9b8be6e9c3e7c63..1dd937a6d0656b53f9e7e0cb25acf80f0c3d59c0 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -46,9 +46,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { // A template for building and running a binary comparison test. template void TestCompare(NativeT lhs, NativeT rhs, bool expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); @@ -58,9 +57,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); @@ -163,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = LiteralUtil::CreateR0(value); + Literal a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), {a_data.get()}); } @@ -227,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - 
std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + Literal a_literal = LiteralUtil::CreateR0(2.1f); + Literal b_literal = LiteralUtil::CreateR0(5.5f); + Literal c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); std::unique_ptr c_data = - client_->TransferToServer(*c_literal).ConsumeValueOrDie(); + client_->TransferToServer(c_literal).ConsumeValueOrDie(); - XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); - XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); - XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c"); Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -379,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, @@ -390,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend / divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -421,9 +419,9 @@ 
XLA_TEST_F(ScalarComputationsTest, RemU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, @@ -432,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend % divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -443,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); - TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); + Literal literal = LiteralUtil::CreateR0(87919); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 99eeb12e2bdd4e8ece42bcd8ffef35b37dfaac48..7e1f4aa0eb4801876d9bdbac6a4d7f1d09f81ba8 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -32,8 +32,7 @@ class ScatterTest : public HloTestBase { RunTest(hlo_text, {operand, scatter_indices, updates}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -63,13 +62,42 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_WithFusedAdds) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + p0 = s32[3,3] parameter(0) + operand = s32[3,3] add(p0, p0) + p1 = s32[2] parameter(1) + indices = s32[2] add(p1, p1) + p2 = s32[2,3] parameter(2) + updates = s32[2,3] add(p2, p2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { @@ -93,13 +121,43 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); - 
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, SimpleR4) { + const char* hlo_text = R"( +HloModule SimpleR4 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[1,2,2,1] parameter(0) + indices = s32[1,3] parameter(1) + updates = f32[1,2,2,1] parameter(2) + ROOT scatter = f32[1,2,2,1] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1,2,3}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0, 2, 1}, + index_vector_dim=1 +} +)"; + + Literal operand = + LiteralUtil::CreateR4({{{{0.f}, {0.f}}, {{0.f}, {0.f}}}}); + Literal updates = + LiteralUtil::CreateR4({{{{0.12}, {0.28}}, {{0.018}, {0.42}}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0, 0}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { @@ -124,13 +182,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { @@ -155,13 +211,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 
90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { @@ -186,13 +240,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { @@ -217,13 +270,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { @@ -248,13 +299,12 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = 
LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { @@ -278,15 +328,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { @@ -310,15 +358,13 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { @@ -342,12 +388,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - 
RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { @@ -371,13 +416,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ZeroDimBounds) { @@ -401,11 +444,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { @@ -430,12 +472,11 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal updates = 
LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { @@ -459,13 +500,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { @@ -489,13 +530,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NegativeIndex) { @@ -519,13 +560,43 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), 
scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}}); + Literal updates = LiteralUtil::CreateR3({{{-10, 10}, {-40, 40}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OneScalarIndex) { @@ -549,12 +620,12 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ScalarUpdate) { @@ -578,10 +649,10 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = LiteralUtil::CreateR0(25); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = 
LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, EmptyIndices) { @@ -605,10 +676,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); - std::unique_ptr updates = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3}); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } } // namespace diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index e3d4f98dd7432d1dce7e697586e8b17105dc82e7..f737b5158b3622d677aea5bf64a421a56e2c42dd 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -42,8 +42,8 @@ struct SelectAndScatterTestParam { std::vector operand_shape; std::vector source_shape; Padding padding_type; - tensorflow::gtl::ArraySlice window_dimensions; - tensorflow::gtl::ArraySlice window_strides; + absl::Span window_dimensions; + absl::Span window_strides; }; class SelectAndScatterTest diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c57bbbd1e4573003d2824aea5fcef36dc55238b5..2cc33ab0963afe8ba2d8e9a6972dcf0622e27c48 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -20,7 +20,9 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -28,8 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), - &expected_literal->shape()); + ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001), + &expected_literal.shape()); } struct R1Spec { @@ -194,14 +194,14 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - // This can't be an std::vector, since you can't grab an ArraySlice of a + // This can't be an std::vector, since you can't grab a Span of a // vector. 
absl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); - auto original = Parameter(&builder, 0, literal->shape(), "p0"); + auto original = Parameter(&builder, 0, literal.shape(), "p0"); Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase, } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -223,9 +223,8 @@ class SliceR1LargeTest : public SliceR1Test {}; string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; - return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, - spec.slice_start, spec.slice_limit, - spec.slice_stride); + return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start, + spec.slice_limit, spec.slice_stride); } XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } @@ -377,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) { input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = Parameter(&builder, 0, literal->shape(), "p0"); + auto a = Parameter(&builder, 0, literal.shape(), "p0"); Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2(&builder, *expected, {arg.get()}); @@ -413,6 +412,7 @@ INSTANTIATE_TEST_CASE_P( R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, // R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, // R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, // + 
R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, // R2Spec{ 511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, // R2Spec{ @@ -468,9 +468,9 @@ class SliceR4Test : public ClientLibraryTestBase, XlaBuilder builder(TestName()); auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal.shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 21c58e075e747af808bd36b54e903c3063149af4..2f18036ff4c5b0bfa28723fb181c33fa6995eb80 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, // array. This is uniqueness is best-effort only. Some types (half and bfloat16) // are not supported and uniqueness cannot be guaranteed if the number of // elements exceeds the number of different values supported by the type. 
-StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { +StatusOr MakeFakeLiteralInternal(const Shape& shape, + std::minstd_rand0* engine, + bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; + std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( - std::unique_ptr element, + Literal element, MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } @@ -131,60 +132,52 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = absl::make_unique(shape); + Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S32: - 
PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { + TF_CHECK_OK( + literal.Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -194,7 +187,7 @@ StatusOr> MakeFakeLiteralInternal( break; default: return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return std::move(literal); } @@ -203,6 +196,7 @@ enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. ConstantType GetInitValue(const HloComputation& computation) { + // TODO(b/77635120): Add init values, for min, max, and their arg variants. 
const HloInstruction* const root = computation.root_instruction(); if (computation.num_parameters() != 2 || root->operand_count() != 2 || root->operand(0)->opcode() != HloOpcode::kParameter || @@ -227,16 +221,16 @@ bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - return ( - ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); + return ((opcode == HloOpcode::kReduceWindow && op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2) || + (opcode == HloOpcode::kReduce && + op_num >= instruction->operand_count() / 2)); } // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomIndex( - tensorflow::gtl::ArraySlice index_space, std::minstd_rand0* engine) { +Literal MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { std::vector start_indices(index_space.size()); if (engine != nullptr) { for (int i = 0; i < index_space.size(); ++i) { @@ -278,9 +272,11 @@ std::vector FindConstrainedUses( constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); } else if (opcode == HloOpcode::kSort && - instruction->operand_count() == 2 && op_num == 0) { + instruction->operand_count() >= 2 && op_num == 0) { // Operand 0 of sort is the array of keys used for key/value - // (two-operand) kSort instructions. + // (two-operand) kSort instructions. Since sort stability is not + // guaranteed, constrain keys of key-value sort not to have duplicates, + // since otherwise the value order may legitimately differ. constrained_uses.push_back(instruction); } } @@ -292,8 +288,8 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. 
If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr> CreateLiteralForConstrainedUses( - const tensorflow::gtl::ArraySlice constrained_uses, +StatusOr CreateLiteralForConstrainedUses( + const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector index_space; bool no_duplicates = false; @@ -342,7 +338,7 @@ StatusOr> CreateLiteralForConstrainedUses( default: return Unimplemented( "Constrained operand generation not implemented for %s.", - use->ToString().c_str()); + use->ToString()); } } int constraint_count = 0; @@ -357,9 +353,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: - return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()); case ConstantType::kOne: - return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. @@ -373,34 +369,33 @@ StatusOr> CreateLiteralForConstrainedUses( // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. 
-StatusOr> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param, - std::minstd_rand0* engine) { +StatusOr MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, + const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random) { +StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random) { +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique() : nullptr; return MakeFakeArguments(module, engine.get()); } -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine) { +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::vector> arguments(params.size()); + std::vector arguments(params.size()); for (int i = 0; i < params.size(); ++i) { arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); @@ -416,4 +411,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, .status(); } +std::unique_ptr CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + DotDimensionNumbers dot_dimension_numbers; + 
dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + return absl::make_unique( + shape, lhs, rhs, dot_dimension_numbers, precision_config); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 277d53d4231d471897d4f0c47d297653ff5561d3..b3c8a739058475a4e51bae6ad2a98152a6532b47 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -21,13 +21,13 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -57,8 +57,8 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random. -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random = true); +StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. @@ -84,20 +84,26 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. 
-StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random = true); +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random = true); // Overload which accepts a random number generator. This enables generation of // different random values with sequential calls to MakeFakeArguments by reusing // the same generator. -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine); +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine); // Check that a given module satisfies various constraints before trying to // execute it. Status VerifyHloModule(HloModule* const module, bool layout_sensitive, bool allow_mixed_precision); +// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of +// the LHS with dimension 0 of the RHS with no batch dimensions. +// Both LHS and the RHS must be of rank 2. +std::unique_ptr CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 322c8ef090cf867f65cada5cb1dbae188f83bad6..bc433eac8fcb02087d8e4eb10f638c85dc141b22 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -85,10 +86,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 3); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -114,10 +115,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -140,12 +141,12 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet key_set; + absl::flat_hash_set key_set; for (const float& value : key_arg.data()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); } @@ -163,12 +164,12 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); 
ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet key_set; + absl::flat_hash_set key_set; for (const int32& value : key_arg.data()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); } diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index c7eb9e2dbe0e27b7933f5861280a3401cd268c08..b34fd0f2e873214c509533f29553af914ddc984d 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { @@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(42, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(42, result.Get({})); } { @@ -204,9 
+201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(false); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(7, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(7, result.Get({})); } } diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 125513ddfd16cb4e742e7d589e22b721307621ee..d6641d257a75945be94d299a1bd4b0366e3759b7 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = LiteralUtil::CreateR0(42); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR0(42); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR0Equal(42, *result); + LiteralTestUtil::ExpectR0Equal(42, result); } XLA_TEST_F(TransferManagerTest, TransferR1F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, - *result); + result); } XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1(test_vector); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR1Equal(test_vector, *result); + LiteralTestUtil::ExpectR1Equal(test_vector, result); } XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1U8(test_string); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_EQ(result->GetR1U8AsString(), test_string); + EXPECT_EQ(result.GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( - LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LayoutUtil::Equal(result.shape().layout(), literal.shape().layout())); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple({}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = LiteralUtil::CreateR1( + Literal literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( + Literal literal = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1( - {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) - .get(), - LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), - LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}), + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f))}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { @@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { // supported. auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(456.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); - - auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); - auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + Literal literal1 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, 
-10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + Literal literal2 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(456.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-98.0f, 153.0f})}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1.shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; auto stream2 = stream_->GetOrCreateSubStream(); - std::unique_ptr result1, result2; + Literal result1, result2; // Round trip literals through device in multiple streams asynchronously. for (int i = 0; i < kIterationCount; ++i) { - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1, device_buffer1)); - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2, device_buffer2)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result1, + Literal this_result1, transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result2, + Literal this_result2, transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); result1 = std::move(this_result1); result2 = std::move(this_result2); } - EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2)); } class TransferDeviceToHostBenchmark : public TransferManagerTest { @@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; 
++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); } tensorflow::testing::StopTiming(); @@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); } tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index c101cd2d20131199801f755c96b629ccb65744db..619d2a388b5646c31f0a61f709a2ab3067e39c03 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ 
b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests a tuple made of scalar constants. @@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar1).get(), - LiteralUtil::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar1), + LiteralUtil::CreateR0(constant_scalar2)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests the creation of tuple data. 
@@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of tuple data. @@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(7.0), LiteralUtil::CreateR1({})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of an empty tuple. @@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. 
@@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2(constant_matrix), + LiteralUtil::CreateR1(constant_vector)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { @@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), - LiteralUtil::CreateR0(!direction).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(direction), + LiteralUtil::CreateR0(!direction)}); - ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()}, error_spec_); } } @@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { @@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), 
tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec1), LiteralUtil::CreateR1(vec2)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { @@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { @@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + LiteralUtil::MakeTuple({&expected_v1, &expected_s}); auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); - auto expected = - LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { @@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( - { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), - }) - .get(), - 
LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}), })) .ConsumeValueOrDie(); @@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0({1, 2}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) - .get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 40000}}}) - .get()}) - .get()})) + {{10000, 20000}, {30000, 40000}}})})})) .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ ->TransferToServer( - *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) + LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); auto sum = LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = absl::make_unique(sum->shape()); - ASSERT_TRUE(prod->Populate( - [&sum](tensorflow::gtl::ArraySlice indexes) { - return sum->Get(indexes) * - (indexes[indexes.size() - 1] == 0 - ? complex64(1, 2) - : complex64(1, -2)); - }) + Literal prod(sum.shape()); + ASSERT_TRUE(prod.Populate([&sum](absl::Span indexes) { + return sum.Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? 
complex64(1, 2) + : complex64(1, -2)); + }) .ok()); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), - LiteralUtil::CreateR0({123, 456}).get()}); - ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices({prod, sum}), + LiteralUtil::CreateR0({123, 456})}); + ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { .ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); - auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), + result)); } // Disabled on interpreter due to lack of outfeed. 
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest, tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { TF_EXPECT_OK(Execute(std::move(module), - {param0.get(), param1.get(), param1.get(), - param0.get(), param4.get()}) + {¶m0, ¶m1, ¶m1, ¶m0, ¶m4}) .status()); })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = Literal::CreateFromShape(expected->shape()); + auto literal = Literal::CreateFromShape(expected.shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), *literal)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); + backend().default_stream_executor(), expected.shape(), literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 8f80a9f3e466d73f2b718452d9a0d64a80c3b36f..4fbd7f2fb174ac899c1e3b23801986cb52db96a2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper() { {-inf(), 0}}); Abs(arg); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = LiteralUtil::CreateR1( + Literal expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -127,9 +127,8 @@ void 
UnaryOpTest::SignAbsTestHelper() { auto abs = Abs(arg); Sub(Mul(sign, ConvertElementType(abs, C64)), arg); - std::unique_ptr expected = - LiteralUtil::CreateR1({0, 0, 0, 0}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { @@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(sgnc, ConvertElementType( Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); - std::unique_ptr expected = - LiteralUtil::CreateR0({-2.6f, 0.8f}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1bdf1867b9330b715b0ba4aca71d56307883c775..6d5f276e82087cedc356691b0ff08df24cec8d20 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -48,7 +48,7 @@ class WhileTest : public ClientLibraryTestBase {}; // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarS32Result) { +XLA_TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. @@ -84,7 +84,7 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarS64Result) { +XLA_TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -114,7 +114,7 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { +XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto result_shape = ShapeUtil::MakeShape(S32, {}); auto orig_shape = ShapeUtil::MakeShape(S32, {2}); @@ -147,7 +147,7 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithPredicateResult) { +XLA_TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. @@ -184,7 +184,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector(0); // } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. @@ -238,7 +238,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { // while (result.sum() < 15.5f) { // result = result + vector(8, 0.125f); // } -TEST_F(WhileTest, WhileWithVectorResult) { +XLA_TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. @@ -298,7 +298,7 @@ TEST_F(WhileTest, WhileWithVectorResult) { // result = result + vector(8, 0.125f); // } // tuple = tuple { while } -TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { +XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. @@ -348,12 +348,12 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // have all reached 2.0. 
auto expected_data = LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = LiteralUtil::MakeTuple({expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { +XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -401,14 +401,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple( + {&expected_counter, &expected_w2, &expected_w3, &expected_w1}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { +XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -466,7 +465,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // get<0>(result) = get<0>(result) + 1; // get<1>(result) = get<1>(result) + 
vector(10, 1.0f); // } -TEST_F(WhileTest, WhileWithTupleResult) { +XLA_TEST_F(WhileTest, WhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -510,13 +509,12 @@ TEST_F(WhileTest, WhileWithTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPredicateTupleResult) { +XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(PRED, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -557,12 +555,12 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_predicate = LiteralUtil::CreateR0(true); - auto expected = LiteralUtil::MakeTuple( - {expected_counter.get(), expected_predicate.get()}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); + auto expected = + LiteralUtil::MakeTuple({&expected_counter, &expected_predicate}); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } -TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { +XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -602,10 +600,9 
@@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR0(7); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests two while nodes when the result type T is a Tuple and the second @@ -622,7 +619,7 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // get<1>(w1) = get<1>(w1) + vector(10, 1.0f); // } // result = get<1>(w0) + get<1>(w1) -TEST_F(WhileTest, TwoWhileWithTupleResult) { +XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -701,7 +698,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { } // Test while nodes that share the while body computation. -TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { +XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -766,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } -// Test while nodes that share the while body computation. -// TODO(b/37245345): Fails on GPU backend. 
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { +XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -886,10 +881,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a vector of S32. @@ -907,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. 
@@ -953,7 +947,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { } } -TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { +XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); @@ -977,15 +971,15 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({&expected_element, &expected_element}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } -TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { +XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); @@ -1005,12 +999,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } -TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { +XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); XlaBuilder outer("outer"); @@ -1031,7 +1025,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(42))); + client_->TransferToServer(LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 
43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1044,7 +1038,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { // result[0] = result[0] + 1; // result[1] = result[1] + 1; // } -TEST_F(WhileTest, WhileWithMixedTupleElements) { +XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) { auto result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); @@ -1070,12 +1064,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(1))); + client_->TransferToServer(LiteralUtil::CreateR0(1))); auto add1 = LiteralUtil::CreateR0(15); auto add2 = LiteralUtil::CreateR0(16); - auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + auto expected = LiteralUtil::MakeTuple({&add1, &add2}); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1152,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -1192,7 +1186,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithLoopInvariantOperation) { +XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto while_shape = ShapeUtil::MakeTupleShape( @@ -1228,7 +1222,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN( - auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + auto param_value, client_->TransferToServer(LiteralUtil::CreateR2( {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( @@ -1236,7 +1230,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { auto while_shape = ShapeUtil::MakeShape(S32, {}); XlaComputation condition; @@ -1258,9 +1252,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { XlaBuilder builder(TestName()); While(condition, body, ConstantR0(&builder, 0)); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(false))); ComputeAndCompareR0(&builder, 2, {}); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 6a7ddd9b55b8ff72a61df5f718f501f02b37302e..a6e70eb6ca25ffac24a8ebaf0420238e109e4fad 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ 
b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -83,8 +83,8 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap* parsed_results, - tensorflow::gtl::ArraySlice opcodes_to_ignore = {}) { + absl::flat_hash_map* parsed_results, + absl::Span opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; @@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -171,10 +171,10 @@ void 
ExecuteAndFetchProfile(string* profile_output, LocalClient* client, ServiceExecutableRunOptions run_options( exec_run_options, /*borrow_stream=*/nullptr, backend->eigen_intra_op_thread_pool()); + std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, - executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, - &hlo_execution_profile)); + executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile)); TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; @@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); - gtl::FlatMap parsed_profile_lines; + absl::flat_hash_map parsed_profile_lines; TF_ASSERT_OK(ParseOneProfileOutputLine( profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); @@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - gtl::FlatMap parsed_profile_lines; + absl::flat_hash_map parsed_profile_lines; for (auto while_body_profile_i = while_body_profile_start + 1; while_body_profile_i != while_body_profile_end; while_body_profile_i++) { diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 9835e3d803a8a873737d9503d588f6caaa749186..cdde88c1359416d423685f330e9cbdf77948034f 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -39,8 +39,7 @@ limitations under the License. 
namespace xla { -StatusOr> TextLiteralReader::ReadPath( - absl::string_view path) { +StatusOr TextLiteralReader::ReadPath(absl::string_view path) { CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; @@ -57,7 +56,7 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -StatusOr> TextLiteralReader::ReadAllLines() { +StatusOr TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); string shape_string; @@ -71,12 +70,12 @@ StatusOr> TextLiteralReader::ReadAllLines() { if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } - auto result = absl::make_unique(shape); + Literal result(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill); + result.PopulateWithValue(fill); std::vector pieces; std::vector coordinates; std::vector coordinate_values; @@ -88,16 +87,16 @@ StatusOr> TextLiteralReader::ReadAllLines() { absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]); if (!absl::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( - "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); + "expected '(' at the beginning of coordinates: \"%s\"", line); } if (!absl::ConsumeSuffix(&coordinates_string, ")")) { return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", - line.c_str()); + line); } float value; - if (!absl::SimpleAtof(absl::string_view(value_string), &value)) { + if (!absl::SimpleAtof(value_string, &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - string(value_string).c_str()); + value_string); } coordinates = absl::StrSplit(coordinates_string, ','); 
coordinate_values.clear(); @@ -106,17 +105,17 @@ StatusOr> TextLiteralReader::ReadAllLines() { if (!absl::SimpleAtoi(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - std::string(piece).c_str()); + std::string(piece)); } coordinate_values.push_back(coordinate_value); } if (coordinate_values.size() != shape.dimensions_size()) { return InvalidArgument( - "line did not have expected number of coordinates; want %d got %zu: " + "line did not have expected number of coordinates; want %d got %u: " "\"%s\"", - shape.dimensions_size(), coordinate_values.size(), line.c_str()); + shape.dimensions_size(), coordinate_values.size(), line); } - result->Set(coordinate_values, value); + result.Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index b265640802c88847ce57e9f942f9f0859b873ae8..c40b43279f56fbd6e8ec91cc45c1f8e4cac8b5ef 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -41,7 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath(absl::string_view path); + static StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -49,7 +49,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. 
- StatusOr> ReadAllLines(); + StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 92f9b4f9f0efa2dc08287bdcbefc88f879164308..1fab4e3a08dd3d76a6efeaabe7bf8ab96892e638 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) { tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents) .ok()); - std::unique_ptr literal = - TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); + Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, literal->Get({0, 0, 0})); - EXPECT_EQ(43.5, literal->Get({0, 0, 1})); - EXPECT_EQ(44.5, literal->Get({0, 0, 2})); - EXPECT_EQ(45.5, literal->Get({0, 1, 0})); - EXPECT_EQ(46.5, literal->Get({0, 1, 1})); - EXPECT_EQ(47.5, literal->Get({0, 1, 2})); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape())); + EXPECT_EQ(42.5, literal.Get({0, 0, 0})); + EXPECT_EQ(43.5, literal.Get({0, 0, 1})); + EXPECT_EQ(44.5, literal.Get({0, 0, 2})); + EXPECT_EQ(45.5, literal.Get({0, 1, 0})); + EXPECT_EQ(46.5, literal.Get({0, 1, 1})); + EXPECT_EQ(47.5, literal.Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 00147015a6b2bf41205a81dddd0b16f5ab434130..7289ae7df65e56652eeeb67e536e4c721d97d999 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" @@ -46,8 +46,7 @@ namespace xla { Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( - [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + [f_ptr, &status](absl::Span indices, const string& value) { if (!status.ok()) { return; } diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 4ea02faffcd52065b05c0444202bd1a3d9d87ee6..5cbaf2fcc192c48092272094710ccaf5c9cf9616 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) { }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); - ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path)); + ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path)); string contents; TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &contents)); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 1e4558814851b238f78f781ab4b6b6bd7608f752..3a086c66bbb37965b1ad7c83a93f0054ae723e87 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -24,6 +24,7 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], ) @@ -43,6 +44,7 @@ cc_library( 
"//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -68,6 +70,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -95,6 +98,7 @@ cc_library( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], alwayslink = True, ) @@ -173,6 +177,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -193,6 +198,8 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -212,6 +219,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index f20dcef382b86d27d7c176ae7e4132ad1db7b901..c866a13de7543fc948311f94708bc6b904717b62 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -38,7 +39,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -46,7 +46,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -77,8 +77,8 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 7aedd1da980d946399c0b8066d046f941d70143e..4375e7c138c9e8d193feaa7a39d63946c4ea3086 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -30,8 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -49,10 +49,9 @@ class OperationDumper : public DfsHloVisitorWithDefault { absl::StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); // Spit `op_name(params...) -> result_type :: path` to stdout. - std::cout << tensorflow::strings::Printf( - "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(), - params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(), - path_.c_str()); + std::cout << absl::StrFormat("%s :: (%s) -> %s :: %s\n", + HloOpcodeString(hlo->opcode()), params, + ShapeUtil::HumanString(hlo->shape()), path_); return Status::OK(); } @@ -60,7 +59,7 @@ class OperationDumper : public DfsHloVisitorWithDefault { string path_; }; -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -105,8 +104,8 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index f03e1b1f965af761c101555fd0275bc0425b9cf0..723569862c7550387e95003e3a673743464b67b8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ 
b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -34,7 +34,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { +void RealMain(absl::Span args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -102,8 +102,8 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc..07ef5ff656bb48519a700a1d7d6c60b655a40ed6 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -45,7 +45,7 @@ using tensorflow::Env; namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -78,8 +78,8 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index 75b63c3b84c21005f64b770c44219d92ffce99df..0c3ec5934e546f551089f715dbbe6f4479e56c3c 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include #include +#include "absl/base/casts.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" @@ -67,9 +67,8 @@ int main(int argc, char** argv) { floats.push_back(value); } - tensorflow::StringPiece content( // non-absl ok - tensorflow::bit_cast(floats.data()), - floats.size() * sizeof(float)); + tensorflow::StringPiece content(absl::bit_cast(floats.data()), + floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output_file, content)); return 0; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 311a1bee8daa3a5d126f00dcabe0675f791adeaa..f910e980535c073562473978662f73f4ee4bee79 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -40,6 +40,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -59,7 +60,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -83,7 +83,8 @@ std::unique_ptr CompileExecutable(const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { + for (const auto& param : + computation.proto().host_program_shape().parameters()) { argument_layouts.push_back(¶m); } return client @@ -121,11 +122,10 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } else { // use recorded data if available for (const auto& proto : module.arguments()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - Literal::CreateFromProto(proto)); + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer data, - client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); scoped_shaped_buffer_arguments.push_back(std::move(data)); } for (const auto& argument : scoped_shaped_buffer_arguments) { @@ -161,12 +161,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, // --generate_fake_infeed is passed and there exists an infeed operation in // the HloSnapshot. 
absl::optional pool; - std::unique_ptr data; + Literal data; if (provide_infeed) { data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); } auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(*data)); + TF_CHECK_OK(client->TransferToInfeed(data)); }; if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", @@ -214,9 +214,9 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "s: " << module.hlo().hlo_module().name(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + TF_ASSIGN_OR_RETURN(Literal result_literal, client->ShapedBufferToLiteral(*result)); - return std::move(*result_literal); + return result_literal; } StatusOr ParseInputFile(const string& filename, @@ -250,10 +250,10 @@ StatusOr ParseInputFile(const string& filename, } fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", filename.c_str()); - return InvalidArgument("Could not parse %s.", filename.c_str()); + return InvalidArgument("Could not parse %s.", filename); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { +int RealMain(absl::Span args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; @@ -305,11 +305,11 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { result.ToString().c_str()); auto& snapshot = snapshots[i]; if (snapshot.has_result()) { - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); + literal.ToString().c_str()); } } } @@ -344,7 +344,7 @@ int main(int argc, char** argv) { LOG(QFATAL) << usage; } - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] return xla::tools::RealMain(args, opts); 
} diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index 51909190a3ef20c3df78d08796e88bdbb650609d..4f8852f8c11fb749ef851bc4faf176fcc5cb3524 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,8 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - std::unique_ptr literal = + xla::Literal literal = xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal->ToString().c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 4e53fafcc97ff53afc5713e7ed8ee5222fac316b..cdf306dfd1027cf6022c5d8ae844b4308f580e8d 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -45,7 +45,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -66,8 +66,8 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 48c837481181f6ad8f864569fd62e0e23fa02ecd..4b5c276bdf66f3dc5364aae4654b13a625b0a4f7 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -36,16 +36,16 @@ int main(int argc, char **argv) { LOG(QFATAL) << "Usage: " << argv[0] << " "; } - std::unique_ptr literal = + xla::Literal literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << *literal; - fprintf(stderr, "%s\n", literal->ToString().c_str()); - if (literal->shape().element_type() == xla::F32) { - float min = *std::min_element(literal->data().begin(), - literal->data().end()); - float max = *std::max_element(literal->data().begin(), - literal->data().end()); + LOG(INFO) << "literal: " << literal; + fprintf(stderr, "%s\n", literal.ToString().c_str()); + if (literal.shape().element_type() == xla::F32) { + float min = *std::min_element(literal.data().begin(), + literal.data().end()); + float max = *std::max_element(literal.data().begin(), + literal.data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 
85f05b7b8d786236ff2fe62cde6a721f5c8c09ea..68cab7387cf1576072f96878b50f07def6862d8b 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stacktrace.h" @@ -68,86 +67,6 @@ Status AppendStatus(Status prior, absl::string_view context) { absl::StrCat(prior.error_message(), ": ", context)}; } -// Implementation note: we can't common these out (without using macros) because -// they all need to va_start/va_end their varargs in their frame. - -Status InvalidArgumentV(const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); -} - -Status InvalidArgument(const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -Status Unimplemented(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unimplemented(message)); -} - -Status InternalError(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Internal(message)); -} - -Status FailedPrecondition(const char* format, ...) 
{ - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message)); -} - -Status Cancelled(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Cancelled(message)); -} - -Status ResourceExhausted(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message)); -} - -Status NotFound(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::NotFound(message)); -} - -Status Unavailable(const char* format, ...) 
{ - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unavailable(message)); -} - string Reindent(absl::string_view original, const absl::string_view indentation) { std::vector pieces = @@ -157,7 +76,7 @@ string Reindent(absl::string_view original, }); } -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { +bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } @@ -171,7 +90,7 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { } std::vector InversePermutation( - tensorflow::gtl::ArraySlice input_permutation) { + absl::Span input_permutation) { DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { @@ -180,8 +99,8 @@ std::vector InversePermutation( return output_permutation; } -std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, - tensorflow::gtl::ArraySlice p2) { +std::vector ComposePermutations(absl::Span p1, + absl::Span p2) { CHECK_EQ(p1.size(), p2.size()); std::vector output; for (size_t i = 0; i < p1.size(); ++i) { @@ -190,7 +109,7 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { +bool IsIdentityPermutation(absl::Span permutation) { for (int64 i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) { return false; @@ -211,7 +130,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { } PaddingConfig MakeEdgePaddingConfig( - tensorflow::gtl::ArraySlice> padding) { + absl::Span> padding) { PaddingConfig padding_config; for (const std::pair& dim : padding) { auto dimension = padding_config.add_dimensions(); @@ -288,14 +207,13 @@ void LogLines(int sev, absl::string_view text, const char* fname, 
int lineno) { } } -int64 Product(tensorflow::gtl::ArraySlice xs) { +int64 Product(absl::Span xs) { return std::accumulate(xs.begin(), xs.end(), static_cast(1), std::multiplies()); } -std::vector> CommonFactors( - tensorflow::gtl::ArraySlice a, - tensorflow::gtl::ArraySlice b) { +std::vector> CommonFactors(absl::Span a, + absl::Span b) { CHECK_EQ(Product(a), Product(b)); if (0 == Product(a)) { return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())}; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 671ef17f36518df0b5e60e9b0c0c76d2e0358e00..8ce741647414a1fa75e6d706ec1e719ace7b7cc8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -27,13 +27,15 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" @@ -99,65 +101,63 @@ struct ScopedLoggingTimer { uint64 start_micros; }; -// Given a vector, returns a MutableArraySlice that points at its +// Given a vector, returns a Span that points at its // internals. // // Warning: if the vector is updated its storage pointer may change, so use this // with caution (ideally in limited scopes with temporary lifetimes). 
template -tensorflow::gtl::MutableArraySlice MutableByteSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(v->data()), v->size() * sizeof(T)); +absl::Span MutableByteSlice(std::vector* v) { + return absl::Span(reinterpret_cast(v->data()), + v->size() * sizeof(T)); } // Turns an immutable slice of type T into an immutable slice of bytes with the // same byte size. template -tensorflow::gtl::ArraySlice CastToByteSlice( - tensorflow::gtl::ArraySlice slice) { - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() * sizeof(T)); +absl::Span CastToByteSlice(absl::Span slice) { + return absl::Span(reinterpret_cast(slice.data()), + slice.size() * sizeof(T)); } // Casts a byte slice to a non-byte type T, checking that the original slice // length is a multiple of sizeof(T). template -tensorflow::gtl::ArraySlice CastByteSlice( - tensorflow::gtl::ArraySlice slice) { +absl::Span CastByteSlice(absl::Span slice) { CHECK_EQ(0, slice.size() % sizeof(T)); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() / sizeof(T)); + return absl::Span(reinterpret_cast(slice.data()), + slice.size() / sizeof(T)); } // Convenience function to force a vector to convert to an immutable slice. template -tensorflow::gtl::ArraySlice AsSlice(const std::vector& v) { - return tensorflow::gtl::ArraySlice(v); +absl::Span AsSlice(const std::vector& v) { + return absl::Span(v); } -// Converts a mutable vector pointer into a MutableArraySlice of the same +// Converts a mutable vector pointer into a Span of the same // type. template -tensorflow::gtl::MutableArraySlice AsMutableSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice(v->data(), v->size()); +absl::Span AsMutableSlice(std::vector* v) { + return absl::Span(v->data(), v->size()); } // xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source. 
// Wrapper function that gives an int64 array slice view of a repeated int64 // protobuf field. -static inline tensorflow::gtl::ArraySlice AsInt64Slice( +static inline absl::Span AsInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // As above, but for uint64 types. -static inline tensorflow::gtl::ArraySlice AsUInt64Slice( +static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // Compares two containers for equality. Returns true iff the two containers @@ -173,7 +173,7 @@ template bool ContainersEqual(const Container1T& c1, std::initializer_list il) { - tensorflow::gtl::ArraySlice c2{il}; + absl::Span c2{il}; return ContainersEqual(c1, c2); } @@ -191,9 +191,9 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, // source and destination. The source starting index is src_base, while the // destination one is dest_base. 
template -void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, - int64 dest_stride, tensorflow::gtl::ArraySlice src, - int64 src_base, int64 src_stride, int64 count) { +void StridedCopy(absl::Span dest, int64 dest_base, int64 dest_stride, + absl::Span src, int64 src_base, int64 src_stride, + int64 count) { for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { dest[dest_base] = static_cast(src[src_base]); } @@ -205,43 +205,73 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, Status AddStatus(Status prior, absl::string_view context); Status AppendStatus(Status prior, absl::string_view context); -// Status error shorthands -- printfs the arguments to be -// used as an error message and returns a status in the canonical -// error space. -Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Cancelled(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); - -// Passed-varargs variant of the InvalidArgument factory above. -Status InvalidArgumentV(const char* format, va_list args); +// Status error shorthands -- StrFormat's the arguments to be used as an error +// message and returns a status in the canonical error space. +template +Status InvalidArgument(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...))); +} +template +Status Unimplemented(const absl::FormatSpec& format, + const Args&... 
args) { + return WithLogBacktrace( + tensorflow::errors::Unimplemented(absl::StrFormat(format, args...))); +} +template +Status InternalError(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Internal(absl::StrFormat(format, args...))); +} +template +Status FailedPrecondition(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...))); +} +template +Status Cancelled(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Cancelled(absl::StrFormat(format, args...))); +} +template +Status ResourceExhausted(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...))); +} +template +Status NotFound(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::NotFound(absl::StrFormat(format, args...))); +} +template +Status Unavailable(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unavailable(absl::StrFormat(format, args...))); +} template Status InvalidArgumentStrCat(Args&&... concat) { - return InvalidArgument("%s", - absl::StrCat(std::forward(concat)...).c_str()); + return InvalidArgument("%s", absl::StrCat(std::forward(concat)...)); } template Status UnimplementedStrCat(Args&&... concat) { - return Unimplemented("%s", - absl::StrCat(std::forward(concat)...).c_str()); + return Unimplemented("%s", absl::StrCat(std::forward(concat)...)); } template Status InternalErrorStrCat(Args&&... concat) { - return InternalError("%s", - absl::StrCat(std::forward(concat)...).c_str()); + return InternalError("%s", absl::StrCat(std::forward(concat)...)); } template Status ResourceExhaustedStrCat(Args&&... 
concat) { - return ResourceExhausted("%s", - absl::StrCat(std::forward(concat)...).c_str()); + return ResourceExhausted("%s", absl::StrCat(std::forward(concat)...)); } // Splits the lines of the original, replaces leading whitespace with the prefix @@ -253,7 +283,7 @@ Status ResourceExhaustedStrCat(Args&&... concat) { string Reindent(absl::string_view original, absl::string_view indentation); // Checks whether permutation is a permutation of the [0, rank) integer range. -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); +bool IsPermutation(absl::Span permutation, int64 rank); // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. @@ -261,10 +291,11 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); // Precondition: // 1. `permutation` is a permutation of 0..permutation.size()-1. // 2. permutation.size() == input.size(). -template